Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
fa9da1a4
Commit
fa9da1a4
authored
Jun 19, 2023
by
Jun Liu
Browse files
Merge branch 'amd-develop' into amd-master
parents
4c105089
457308e3
Changes
1000
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
107 additions
and
187 deletions
+107
-187
example/12_reduce/reduce_example_common.hpp
example/12_reduce/reduce_example_common.hpp
+1
-1
example/12_reduce/reduce_multiblock_atomic_add.cpp
example/12_reduce/reduce_multiblock_atomic_add.cpp
+1
-1
example/12_reduce/reduce_multiblock_atomic_add_impl.hpp
example/12_reduce/reduce_multiblock_atomic_add_impl.hpp
+1
-1
example/13_pool2d_fwd/pool2d_fwd_common.hpp
example/13_pool2d_fwd/pool2d_fwd_common.hpp
+41
-133
example/13_pool2d_fwd/pool2d_fwd_fp16.cpp
example/13_pool2d_fwd/pool2d_fwd_fp16.cpp
+5
-6
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
+5
-6
example/14_gemm_quantization/CMakeLists.txt
example/14_gemm_quantization/CMakeLists.txt
+9
-2
example/14_gemm_quantization/gemm_dl_quantization_int8.cpp
example/14_gemm_quantization/gemm_dl_quantization_int8.cpp
+1
-1
example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp
...emm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp
+1
-1
example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp
example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp
+1
-1
example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp
+1
-1
example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp
+1
-1
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+1
-1
example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp
+1
-1
example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp
example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp
+1
-1
example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp
example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp
+1
-1
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
+1
-1
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
+32
-25
example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp
...d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp
+1
-1
example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp
...emm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp
+1
-1
No files found.
Too many changes to show.
To preserve performance only
1000 of 1000+
files are displayed.
Plain diff
Email patch
example/12_reduce/reduce_example_common.hpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
example/12_reduce/reduce_multiblock_atomic_add.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <initializer_list>
...
...
example/12_reduce/reduce_multiblock_atomic_add_impl.hpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
example/13_pool2d_fwd/pool2d_fwd_common.hpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -17,115 +17,11 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp"
template
<
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
IndexDataType
,
ck
::
ReduceTensorOp
ReduceOpId
,
bool
PropagateNan
,
bool
OutputIndex
>
static
void
pool_host_verify
(
const
Tensor
<
InDataType
>&
in
,
Tensor
<
OutDataType
>&
out
,
Tensor
<
IndexDataType
>&
out_indices
,
const
std
::
array
<
ck
::
index_t
,
2
>&
window_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
2
>&
window_strides
,
const
std
::
array
<
ck
::
index_t
,
2
>&
in_left_pads
,
const
std
::
array
<
ck
::
index_t
,
2
>&
/*in_right_pads*/
)
{
const
int32_t
reduceLength
=
window_spatial_lengths
[
0
]
*
window_spatial_lengths
[
1
];
using
ReduceOperation
=
typename
ck
::
reduce_binary_operator
<
ReduceOpId
>::
opType
;
auto
elementwise_ops
=
ck
::
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
GetElementwiseOperator
(
reduceLength
);
auto
in_elementwise_op
=
std
::
get
<
0
>
(
elementwise_ops
);
auto
acc_elementwise_op
=
std
::
get
<
1
>
(
elementwise_ops
);
if
constexpr
(
!
OutputIndex
)
{
using
Accumulation
=
ck
::
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
ReduceOperation
,
AccDataType
>
;
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
ho
,
auto
wo
)
{
auto
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>();
for
(
ck
::
index_t
y
=
0
;
y
<
window_spatial_lengths
[
0
];
++
y
)
{
ck
::
index_t
hi
=
ho
*
window_strides
[
0
]
+
y
-
in_left_pads
[
0
];
for
(
ck
::
index_t
x
=
0
;
x
<
window_spatial_lengths
[
1
];
++
x
)
{
ck
::
index_t
wi
=
wo
*
window_strides
[
1
]
+
x
-
in_left_pads
[
1
];
if
(
hi
>=
0
&&
hi
<
static_cast
<
ck
::
index_t
>
(
in
.
mDesc
.
GetLengths
()[
2
])
&&
wi
>=
0
&&
wi
<
static_cast
<
ck
::
index_t
>
(
in
.
mDesc
.
GetLengths
()[
3
]))
{
AccDataType
currVal
=
static_cast
<
AccDataType
>
(
in
(
n
,
c
,
hi
,
wi
));
in_elementwise_op
(
currVal
,
currVal
);
Accumulation
::
Calculate
(
accuVal
,
currVal
);
}
}
}
acc_elementwise_op
(
accuVal
,
accuVal
);
out
(
n
,
c
,
ho
,
wo
)
=
accuVal
;
};
make_ParallelTensorFunctor
(
f_nchw
,
out
.
mDesc
.
GetLengths
()[
0
],
out
.
mDesc
.
GetLengths
()[
1
],
out
.
mDesc
.
GetLengths
()[
2
],
out
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
}
else
{
using
Accumulation
=
ck
::
detail
::
AccumulateWithIndexAndNanCheck
<
PropagateNan
,
ReduceOperation
,
AccDataType
,
IndexDataType
>
;
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
ho
,
auto
wo
)
{
auto
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>();
IndexDataType
accuIndex
=
0
;
for
(
ck
::
index_t
y
=
0
;
y
<
window_spatial_lengths
[
0
];
++
y
)
{
ck
::
index_t
hi
=
ho
*
window_strides
[
0
]
+
y
-
in_left_pads
[
0
];
for
(
ck
::
index_t
x
=
0
;
x
<
window_spatial_lengths
[
1
];
++
x
)
{
ck
::
index_t
wi
=
wo
*
window_strides
[
1
]
+
x
-
in_left_pads
[
1
];
if
(
hi
>=
0
&&
hi
<
in
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in
.
mDesc
.
GetLengths
()[
3
])
{
AccDataType
currVal
=
static_cast
<
AccDataType
>
(
in
(
n
,
c
,
hi
,
wi
));
IndexDataType
currIndex
=
y
*
window_spatial_lengths
[
1
]
+
x
;
in_elementwise_op
(
currVal
,
currVal
);
Accumulation
::
Calculate
(
accuVal
,
currVal
,
accuIndex
,
currIndex
);
}
}
}
acc_elementwise_op
(
accuVal
,
accuVal
);
out
(
n
,
c
,
ho
,
wo
)
=
accuVal
;
out_indices
(
n
,
c
,
ho
,
wo
)
=
accuIndex
;
};
make_ParallelTensorFunctor
(
f_nchw
,
out
.
mDesc
.
GetLengths
()[
0
],
out
.
mDesc
.
GetLengths
()[
1
],
out
.
mDesc
.
GetLengths
()[
2
],
out
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
};
}
template
<
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
ComputeDataType
,
typename
IndexDataType
,
typename
InLayout
,
typename
OutLayout
,
...
...
@@ -150,9 +46,10 @@ bool pool_test(bool do_verification,
{
using
DevicePoolFwdInstance
=
ck
::
tensor_operation
::
device
::
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
<
InDataType
,
// InDataType
OutDataType
,
// OutDataType
AccDataType
,
// AccDataType
InDataType
,
// InDataType
OutDataType
,
// OutDataType
IndexDataType
,
// IndexDataType
ComputeDataType
,
// ComputeDataType
ReduceOpId
,
OutputIndex
,
64
,
// BlockSize
...
...
@@ -165,10 +62,10 @@ bool pool_test(bool do_verification,
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Y
)
/
window_stride_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
X
)
/
window_stride_w
+
1
;
const
std
::
array
<
ck
::
index_t
,
2
>
window_spatial_lengths
{
{
Y
,
X
}
}
;
const
std
::
array
<
ck
::
index_t
,
2
>
window_strides
{
{
window_stride_h
,
window_stride_w
}
}
;
const
std
::
array
<
ck
::
index_t
,
2
>
input_left_pads
{
{
in_left_pad_h
,
in_left_pad_w
}
}
;
const
std
::
array
<
ck
::
index_t
,
2
>
input_right_pads
{
{
in_right_pad_h
,
in_right_pad_w
}
}
;
const
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
{
Y
,
X
};
const
std
::
vector
<
ck
::
index_t
>
window_strides
{
window_stride_h
,
window_stride_w
};
const
std
::
vector
<
ck
::
index_t
>
input_left_pads
{
in_left_pad_h
,
in_left_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_right_pads
{
in_right_pad_h
,
in_right_pad_w
};
// tensor layout
auto
f_host_tensor_descriptor
=
...
...
@@ -219,14 +116,16 @@ bool pool_test(bool do_verification,
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
N
,
C
,
std
::
array
<
ck
::
index_t
,
2
>
{{
Hi
,
Wi
}},
std
::
array
<
ck
::
index_t
,
2
>
{{
Y
,
X
}},
std
::
array
<
ck
::
index_t
,
2
>
{{
Ho
,
Wo
}},
{
N
,
C
,
Hi
,
Wi
},
{
Y
,
X
},
{
N
,
C
,
Ho
,
Wo
},
{
C
*
Hi
*
Wi
,
1
,
Wi
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
window_strides
,
input_left_pads
,
input_right_pads
);
input_right_pads
,
{
2
,
3
});
if
(
!
pool
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
...
...
@@ -252,19 +151,28 @@ bool pool_test(bool do_verification,
if
(
do_verification
)
{
pool_host_verify
<
InDataType
,
OutDataType
,
AccDataType
,
IndexDataType
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
in_n_c_hi_wi
,
out_n_c_ho_wo_host
,
out_indices_n_c_ho_wo_host
,
window_spatial_lengths
,
window_strides
,
input_left_pads
,
input_right_pads
);
using
ReferencePoolingFwdInstance
=
ck
::
tensor_operation
::
host
::
ReferencePoolingFwd
<
4
,
2
,
InDataType
,
OutDataType
,
ComputeDataType
,
IndexDataType
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
;
auto
ref_pooling
=
ReferencePoolingFwdInstance
{};
auto
ref_pooling_invoker
=
ref_pooling
.
MakeInvoker
();
auto
ref_pooling_argument
=
ref_pooling
.
MakeArgument
(
in_n_c_hi_wi
,
out_n_c_ho_wo_host
,
out_indices_n_c_ho_wo_host
,
window_spatial_lengths
,
window_strides
,
input_left_pads
,
input_right_pads
);
ref_pooling_invoker
.
Run
(
ref_pooling_argument
);
out_device_buf
.
FromDevice
(
out_n_c_ho_wo_device
.
mData
.
data
());
...
...
example/13_pool2d_fwd/pool2d_fwd_fp16.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...
...
@@ -10,9 +9,9 @@
#include "pool2d_fwd_common.hpp"
using
InDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
Acc
DataType
=
float
;
using
InDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
Compute
DataType
=
float
;
using
IndexDataType
=
int32_t
;
...
...
@@ -91,7 +90,7 @@ int main(int argc, char* argv[])
bool
pass
=
pool_test
<
InDataType
,
OutDataType
,
Acc
DataType
,
Compute
DataType
,
IndexDataType
,
InLayout
,
OutLayout
,
...
...
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
...
...
@@ -10,9 +9,9 @@
#include "pool2d_fwd_common.hpp"
using
InDataType
=
float
;
using
OutDataType
=
float
;
using
Acc
DataType
=
float
;
using
InDataType
=
float
;
using
OutDataType
=
float
;
using
Compute
DataType
=
float
;
using
IndexDataType
=
int32_t
;
...
...
@@ -91,7 +90,7 @@ int main(int argc, char* argv[])
bool
pass
=
pool_test
<
InDataType
,
OutDataType
,
Acc
DataType
,
Compute
DataType
,
IndexDataType
,
InLayout
,
OutLayout
,
...
...
example/14_gemm_quantization/CMakeLists.txt
View file @
fa9da1a4
...
...
@@ -2,5 +2,12 @@
add_example_executable
(
example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp
)
# xdlops
add_example_executable
(
example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp
)
add_example_executable
(
example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp
)
\ No newline at end of file
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_example_executable
(
example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp
)
add_example_executable
(
example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp
)
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
example/14_gemm_quantization/gemm_dl_quantization_int8.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstddef>
...
...
example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
View file @
fa9da1a4
add_custom_target
(
example_gemm_reduce_xdl
)
add_custom_target
(
example_gemm_reduce_xdl_max
)
add_custom_target
(
example_gemm_reduce_xdl_mean_meansquare
)
add_custom_target
(
example_gemm_add_add_mean_meansquare_xdl
)
add_example_executable
(
example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp
)
add_example_executable
(
example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp
)
add_example_executable
(
example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp
)
add_dependencies
(
example_gemm_reduce_xdl_max
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_gemm_reduce_xdl
)
add_custom_target
(
example_gemm_reduce_xdl_max
)
add_custom_target
(
example_gemm_reduce_xdl_mean_meansquare
)
add_custom_target
(
example_gemm_add_add_mean_meansquare_xdl
)
add_example_executable
(
example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp
)
add_example_executable
(
example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp
)
add_example_executable
(
example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp
)
add_dependencies
(
example_gemm_reduce_xdl_max
example_gemm_max_xdl_bf16
example_gemm_max_xdl_fp16
example_gemm_max_xdl_fp32
example_gemm_max_xdl_int8
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare
example_gemm_mean_meansquare_xdl_fp16
example_gemm_mean_meansquare_xdl_fp32
example_gemm_mean_meansquare_xdl_bf16
example_gemm_add_addsquare_xdl_int8
)
add_dependencies
(
example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16
)
add_dependencies
(
example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16
)
add_dependencies
(
example_gemm_reduce_xdl
add_dependencies
(
example_gemm_reduce_xdl
example_gemm_reduce_xdl_mean_meansquare
example_gemm_reduce_xdl_max
example_gemm_add_add_mean_meansquare_xdl
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_int4
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_int4
)
endif
()
set
(
target 1
)
endif
()
endforeach
()
example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp
View file @
fa9da1a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_reduce_xdl_common.hpp"
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
50
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment