Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
dddc2115
"vscode:/vscode.git/clone" did not exist on "48c56dd360089ea58bcd4557f2f19c0056244212"
Commit
dddc2115
authored
Aug 08, 2023
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
6e01019b
08eb1769
Changes
139
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
340 additions
and
339 deletions
+340
-339
example/40_conv2d_fwd_quantization/CMakeLists.txt
example/40_conv2d_fwd_quantization/CMakeLists.txt
+14
-16
example/41_grouped_conv_conv_fwd/CMakeLists.txt
example/41_grouped_conv_conv_fwd/CMakeLists.txt
+12
-4
example/42_groupnorm/CMakeLists.txt
example/42_groupnorm/CMakeLists.txt
+5
-3
example/43_splitk_gemm_bias_e_permute/CMakeLists.txt
example/43_splitk_gemm_bias_e_permute/CMakeLists.txt
+6
-2
example/44_elementwise_permute/CMakeLists.txt
example/44_elementwise_permute/CMakeLists.txt
+4
-2
example/46_gemm_add_multiply/CMakeLists.txt
example/46_gemm_add_multiply/CMakeLists.txt
+6
-2
example/48_pool3d_fwd/CMakeLists.txt
example/48_pool3d_fwd/CMakeLists.txt
+3
-2
example/49_maxpool2d_bwd/CMakeLists.txt
example/49_maxpool2d_bwd/CMakeLists.txt
+9
-3
example/50_put_element/CMakeLists.txt
example/50_put_element/CMakeLists.txt
+3
-1
include/ck/ck.hpp
include/ck/ck.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
...r_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
+6
-9
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
...impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
+78
-81
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+147
-187
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
...ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
+22
-13
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp
...operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp
+2
-1
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp
...nsor_operation_instance/gpu/batched_gemm_bias_permute.hpp
+2
-1
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
...n_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
+8
-4
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
...brary/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
+2
-1
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
...ry/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
+8
-5
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
...nsor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
+2
-1
No files found.
example/40_conv2d_fwd_quantization/CMakeLists.txt
View file @
dddc2115
...
@@ -10,21 +10,19 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -10,21 +10,19 @@ foreach(gpu IN LISTS GPU_TARGETS)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
# Conv perlayer quantization
add_example_executable
(
example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp
)
# Conv perchannel quantization
if
(
DL_KERNELS
)
add_example_executable
(
example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_
quantization
_int8.cpp
)
# Conv perlayer
quantization
add_example_executable
(
example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp
)
# Conv
+ bias + relu perlayer
quantization
# Conv
perchannel
quantization
add_example_executable
(
example_conv2d_fwd_dl_
bias_relu_perlayer
_quantization_int8 conv2d_fwd_dl_
bias_relu_perlayer
_quantization_int8.cpp
)
add_example_executable
(
example_conv2d_fwd_dl_
perchannel
_quantization_int8 conv2d_fwd_dl_
perchannel
_quantization_int8.cpp
)
# Conv + bias + relu perlayer quantization
# Conv +
bias
+
relu
per
channel
quantization
add_example_executable
(
example_conv2d_fwd_dl_
bias
_
relu
_
per
layer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_
quantization
_int8.cpp
)
add_example_executable
(
example_conv2d_fwd_dl_
bias
_
relu
_
perchannel
_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_
quantization
_int8.cpp
)
# Conv +
bias
+
relu
perchannel
quantization
add_example_executable
(
example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp
)
# Conv + bias + tanh perlayer quantization
# Conv + bias + tanh perlayer quantization
add_example_executable
(
example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp
)
add_example_executable
(
example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp
)
# Conv + bias + tanh perchannel quantization
# Conv +
bias
+
tanh
perchannel
quantization
add_example_executable
(
example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_
bias
_
tanh
_
perchannel
_
quantization
_int8.cpp
)
add_example_executable
(
example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp
)
endif
(
)
endif
()
endif
()
\ No newline at end of file
example/41_grouped_conv_conv_fwd/CMakeLists.txt
View file @
dddc2115
...
@@ -3,9 +3,15 @@ list(APPEND gpu_list2 gfx908 gfx90a)
...
@@ -3,9 +3,15 @@ list(APPEND gpu_list2 gfx908 gfx90a)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list1 AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list1 AND target EQUAL 0
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp
)
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp
)
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp
)
endif
(
USE_BITINT_EXTENSION_INT4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
...
@@ -14,5 +20,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -14,5 +20,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
endforeach
()
endforeach
()
if
(
NOT GPU_TARGETS MATCHES
"gfx94"
AND NOT GPU_TARGETS MATCHES
"gfx1"
)
if
(
NOT GPU_TARGETS MATCHES
"gfx94"
AND NOT GPU_TARGETS MATCHES
"gfx1"
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp
)
endif
()
endif
()
endif
()
example/42_groupnorm/CMakeLists.txt
View file @
dddc2115
add_example_executable
(
example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp
)
add_example_executable
(
example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp
)
add_example_executable
(
example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp
)
add_example_executable
(
example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp
)
add_example_executable
(
example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp
)
endif
()
example/43_splitk_gemm_bias_e_permute/CMakeLists.txt
View file @
dddc2115
add_example_executable
(
example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp
)
add_example_executable
(
example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp
)
endif
()
example/44_elementwise_permute/CMakeLists.txt
View file @
dddc2115
add_example_executable
(
example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp
)
add_example_executable
(
example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp
)
add_example_executable
(
example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp
)
endif
()
example/46_gemm_add_multiply/CMakeLists.txt
View file @
dddc2115
add_example_executable
(
example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp
)
if
(
DL_KERNELS
)
add_example_executable
(
example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp
)
endif
()
add_example_executable
(
example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp
)
endif
()
example/48_pool3d_fwd/CMakeLists.txt
View file @
dddc2115
add_example_executable
(
example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp
)
endif
()
example/49_maxpool2d_bwd/CMakeLists.txt
View file @
dddc2115
add_example_executable
(
example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp
)
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp
)
add_example_executable
(
example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp
)
add_example_executable
(
example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp
)
endif
()
example/50_put_element/CMakeLists.txt
View file @
dddc2115
add_example_executable
(
example_put_element_fp16 put_element_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_put_element_fp16 put_element_fp16.cpp
)
endif
()
include/ck/ck.hpp
View file @
dddc2115
...
@@ -198,7 +198,7 @@
...
@@ -198,7 +198,7 @@
#define CK_WORKAROUND_SWDEV_388832 1
#define CK_WORKAROUND_SWDEV_388832 1
// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
#define CK_WORKAROUND_
SWDEV_3318619 0
#define CK_WORKAROUND_
GITHUB_ISSUE_824 1
// flag to enable (1) or disable (0) the debugging output in some kernels
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
#define DEBUG_LOG 0
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
View file @
dddc2115
...
@@ -27,15 +27,12 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
...
@@ -27,15 +27,12 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
MakeArgumentPointer
(
const
void
*
p_in
,
MakeArgumentPointer
(
const
void
*
p_in
,
void
*
p_wei
,
void
*
p_wei
,
const
void
*
p_out
,
const
void
*
p_out
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
View file @
dddc2115
...
@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
Argument
(
const
InDataType
*
p_in_grid
,
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_out_grid
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/*a_g_n_c_wis_strides*/
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/*b_g_k_c_xs_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/*e_g_n_k_wos_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
/*input_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
/*output_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
a_element_op_
{
out_element_op
},
a_element_op_
{
out_element_op
},
b_element_op_
{
wei_element_op
},
b_element_op_
{
wei_element_op
},
c_element_op_
{
in_element_op
},
c_element_op_
{
in_element_op
},
Conv_G_
{
G
},
Conv_G_
{
a_g_n_c_wis_lengths
[
0
]
},
Conv_N_
{
N
},
Conv_N_
{
a_g_n_c_wis_lengths
[
1
]
},
Conv_K_
{
K
},
Conv_K_
{
b_g_k_c_xs_lengths
[
1
]
},
Conv_C_
{
C
},
Conv_C_
{
a_g_n_c_wis_lengths
[
2
]
},
input_spatial_lengths_
{
input_spatial_lengths
},
input_spatial_lengths_
{},
filter_spatial_lengths_
{
filter_spatial_lengths
},
filter_spatial_lengths_
{},
output_spatial_lengths_
{
output_spatial_lengths
},
output_spatial_lengths_
{},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
input_right_pads_
{
input_right_pads
},
k_batch_
{
split_k
}
k_batch_
{
split_k
}
{
{
constexpr
index_t
spatial_offset
=
3
;
std
::
copy
(
begin
(
a_g_n_c_wis_lengths
)
+
spatial_offset
,
end
(
a_g_n_c_wis_lengths
),
begin
(
input_spatial_lengths_
));
std
::
copy
(
begin
(
b_g_k_c_xs_lengths
)
+
spatial_offset
,
end
(
b_g_k_c_xs_lengths
),
begin
(
filter_spatial_lengths_
));
std
::
copy
(
begin
(
e_g_n_k_wos_lengths
)
+
spatial_offset
,
end
(
e_g_n_k_wos_lengths
),
begin
(
output_spatial_lengths_
));
const
auto
descs
=
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
N
,
Conv_N_
,
K
,
Conv_K_
,
C
,
C
onv_C_
,
input_spatial_lengths
,
input_spatial_lengths
_
,
filter_spatial_lengths
,
filter_spatial_lengths
_
,
output_spatial_lengths
,
output_spatial_lengths
_
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// A/B/C Batch Stride
// A/B/C Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
N
*
K
*
Conv_N_
*
Conv_K_
*
std
::
accumulate
(
begin
(
output_spatial_lengths
),
std
::
accumulate
(
begin
(
output_spatial_lengths
_
),
end
(
output_spatial_lengths
),
end
(
output_spatial_lengths
_
),
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<>
{});
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
N
*
C
*
Conv_N_
*
Conv_C_
*
std
::
accumulate
(
begin
(
input_spatial_lengths
),
std
::
accumulate
(
begin
(
input_spatial_lengths
_
),
end
(
input_spatial_lengths
),
end
(
input_spatial_lengths
_
),
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<>
{});
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
K
*
C
*
Conv_K_
*
Conv_C_
*
std
::
accumulate
(
begin
(
filter_spatial_lengths
),
std
::
accumulate
(
begin
(
filter_spatial_lengths
_
),
end
(
filter_spatial_lengths
),
end
(
filter_spatial_lengths
_
),
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<>
{});
std
::
multiplies
<>
{});
}
}
...
@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const
index_t
Conv_K_
;
const
index_t
Conv_K_
;
const
index_t
Conv_C_
;
const
index_t
Conv_C_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
filter_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
output_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads_
;
...
@@ -1110,39 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -1110,39 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
static
auto
WeiDataType
*
p_wei_grid
,
MakeArgument
(
const
InDataType
*
p_in_grid
,
const
OutDataType
*
p_out_grid
,
WeiDataType
*
p_wei_grid
,
const
ck
::
index_t
G
,
const
OutDataType
*
p_out_grid
,
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
InElementwiseOperation
in_element_op
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
WeiElementwiseOperation
wei_element_op
,
InElementwiseOperation
in_element_op
,
OutElementwiseOperation
out_element_op
,
WeiElementwiseOperation
wei_element_op
,
ck
::
index_t
split_k
)
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
{
{
return
Argument
{
p_in_grid
,
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_wei_grid
,
p_out_grid
,
p_out_grid
,
G
,
a_g_n_c_wis_lengths
,
// input
N
,
a_g_n_c_wis_strides
,
K
,
b_g_k_c_xs_lengths
,
// weight
C
,
b_g_k_c_xs_strides
,
input_spatial_lengths
,
e_g_n_k_wos_lengths
,
// output
filter_spatial_lengths
,
e_g_n_k_wos_strides
,
output_spatial_lengths
,
input_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
const
void
*
p_out_grid
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
G
,
a_g_n_c_wis_lengths
,
// input
N
,
a_g_n_c_wis_strides
,
K
,
b_g_k_c_xs_lengths
,
// weight
C
,
b_g_k_c_xs_strides
,
input_spatial_lengths
,
e_g_n_k_wos_lengths
,
// output
filter_spatial_lengths
,
e_g_n_k_wos_strides
,
output_spatial_lengths
,
input_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
dddc2115
This diff is collapsed.
Click to expand it.
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
View file @
dddc2115
...
@@ -16,7 +16,7 @@ namespace ck {
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef __bf16__
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
(
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -36,7 +36,8 @@ void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
...
@@ -36,7 +36,8 @@ void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
#endif
#ifdef __fp16__
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
(
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -56,7 +57,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
...
@@ -56,7 +57,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
#endif
#ifdef __fp32__
void
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
(
void
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -76,7 +78,8 @@ void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
...
@@ -76,7 +78,8 @@ void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
#endif
#ifdef __int8__
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
(
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
...
@@ -120,7 +123,7 @@ void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
...
@@ -120,7 +123,7 @@ void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
,
typename
CLayout
,
...
@@ -151,7 +154,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
...
@@ -151,7 +154,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef __fp32__
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
CDataType
,
float
>
)
is_same_v
<
CDataType
,
float
>
)
{
{
...
@@ -176,8 +179,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
...
@@ -176,8 +179,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
#endif
is_same_v
<
CDataType
,
half_t
>
)
#ifdef __fp16__
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
...
@@ -200,8 +205,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
...
@@ -200,8 +205,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
bhalf_t
>
&&
#endif
is_same_v
<
CDataType
,
bhalf_t
>
)
#ifdef __bf16__
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
bhalf_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
...
@@ -224,8 +231,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
...
@@ -224,8 +231,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
#endif
is_same_v
<
CDataType
,
int8_t
>
)
#ifdef __int8__
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
int8_t
>
)
{
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
...
@@ -248,7 +257,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
...
@@ -248,7 +257,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
}
}
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp
View file @
dddc2115
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
using
CDE0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
using
CDE0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
using
CDE1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
CDE1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Add
;
#ifdef __fp16__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -137,3 +137,4 @@ struct DeviceOperationInstanceFactory<
...
@@ -137,3 +137,4 @@ struct DeviceOperationInstanceFactory<
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp
View file @
dddc2115
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -91,3 +91,4 @@ struct DeviceOperationInstanceFactory<
...
@@ -91,3 +91,4 @@ struct DeviceOperationInstanceFactory<
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
View file @
dddc2115
...
@@ -16,7 +16,7 @@ namespace ck {
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef __fp16__
void
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
...
@@ -58,7 +58,8 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_
...
@@ -58,7 +58,8 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
instances
);
#endif
#ifdef __bf16__
void
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
...
@@ -100,7 +101,7 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf
...
@@ -100,7 +101,7 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
instances
);
#endif
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
B0DataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
B1DataType
,
...
@@ -147,7 +148,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -147,7 +148,7 @@ struct DeviceOperationInstanceFactory<
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef __fp16__
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
B0DataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
B0DataType
,
half_t
>
&&
is_same_v
<
B1DataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
&&
is_same_v
<
B1DataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
&&
Acc0BiasDataType
::
Size
()
==
1
&&
Acc0BiasDataType
::
Size
()
==
1
&&
...
@@ -164,6 +165,8 @@ struct DeviceOperationInstanceFactory<
...
@@ -164,6 +165,8 @@ struct DeviceOperationInstanceFactory<
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
#ifdef __bf16__
else
if
constexpr
(
is_same_v
<
ADataType
,
BF16
>
&&
is_same_v
<
B0DataType
,
BF16
>
&&
else
if
constexpr
(
is_same_v
<
ADataType
,
BF16
>
&&
is_same_v
<
B0DataType
,
BF16
>
&&
is_same_v
<
B1DataType
,
BF16
>
&&
is_same_v
<
CDataType
,
BF16
>
&&
is_same_v
<
B1DataType
,
BF16
>
&&
is_same_v
<
CDataType
,
BF16
>
&&
Acc0BiasDataType
::
Size
()
==
1
&&
Acc0BiasDataType
::
Size
()
==
1
&&
...
@@ -180,6 +183,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -180,6 +183,7 @@ struct DeviceOperationInstanceFactory<
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
View file @
dddc2115
...
@@ -16,7 +16,7 @@ namespace ck {
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef __fp16__
void
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
void
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmGemm
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmGemm
<
Row
,
Col
,
Col
,
...
@@ -111,3 +111,4 @@ struct DeviceOperationInstanceFactory<
...
@@ -111,3 +111,4 @@ struct DeviceOperationInstanceFactory<
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
View file @
dddc2115
...
@@ -19,7 +19,7 @@ namespace ck {
...
@@ -19,7 +19,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef __fp16__
void
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances
(
void
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmMultiD
<
Col
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmMultiD
<
Col
,
Row
,
Row
,
...
@@ -123,7 +123,8 @@ void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instan
...
@@ -123,7 +123,8 @@ void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instan
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#ifdef __int8__
void
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances
(
void
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmMultiD
<
Col
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmMultiD
<
Col
,
Row
,
Row
,
...
@@ -227,7 +228,7 @@ void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances
...
@@ -227,7 +228,7 @@ void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
ELayout
,
typename
ELayout
,
...
@@ -262,7 +263,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
...
@@ -262,7 +263,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef __fp16__
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
is_same_v
<
EDataType
,
half_t
>
)
{
{
...
@@ -295,6 +296,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
...
@@ -295,6 +296,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
#ifdef __int8__
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
EDataType
,
int8_t
>
)
is_same_v
<
EDataType
,
int8_t
>
)
{
{
...
@@ -327,7 +330,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
...
@@ -327,7 +330,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
View file @
dddc2115
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory<
...
@@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory<
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment