Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9b3c4ac4
Commit
9b3c4ac4
authored
May 14, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
1d784873
7843a8a7
Changes
83
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
86 additions
and
96 deletions
+86
-96
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
...l/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
...ion/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
..._batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
+25
-24
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
...u/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
+16
-15
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+4
-3
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
...device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+19
-18
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
..._fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
+1
-4
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
...nv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
+1
-4
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+1
-4
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
.../gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
...u/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
...gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
+3
-4
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
...pu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
...or_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
...de/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
+3
-4
include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp
...e/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
...r_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
.../gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
+1
-2
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
View file @
9b3c4ac4
...
...
@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
is_
navi3
_supported
())
if
(
ck
::
is_
gfx11
_supported
())
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
View file @
9b3c4ac4
...
...
@@ -648,7 +648,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_xdl_supported
()
||
ck
::
is_
navi2
_supported
()
||
ck
::
is_
navi3
_supported
())
ck
::
is_
gfx103
_supported
()
||
ck
::
is_
gfx11
_supported
())
{
bool
pass
=
true
;
pass
=
pass
&&
arg
.
K_
%
K1
==
0
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
View file @
9b3c4ac4
...
...
@@ -587,13 +587,14 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
BatchStrideD1s
,
BatchStrideE1
}
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"a0_grid_desc_m_k_{"
<<
a0_grid_desc_m_k_
.
GetLength
(
I0
)
<<
", "
<<
a0_grid_desc_m_k_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"b0_grid_desc_n_k_{"
<<
b0_grid_desc_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b0_grid_desc_n_k_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"d0s_grid_desc_m_n_[I0]{"
<<
d0s_grid_desc_m_n_
[
I0
].
GetLength
(
I0
)
<<
", "
<<
d0s_grid_desc_m_n_
[
I0
].
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"d0s_grid_desc_m_n_[I0]{"
<<
d0s_grid_desc_m_n_
[
I0
].
GetLength
(
I0
)
<<
", "
<<
d0s_grid_desc_m_n_
[
I0
].
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"b1_grid_desc_n_k_{"
<<
b1_grid_desc_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1_grid_desc_n_k_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
...
...
@@ -610,7 +611,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<<
std
::
endl
;
std
::
cout
<<
"e1_grid_desc_m_n_{"
<<
e1_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
e1_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
#endif
}
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
using
D0Layout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sLayout
>>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
9b3c4ac4
...
...
@@ -658,7 +658,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"arg.Batch_ = "
<<
arg
.
Batch_
<<
std
::
endl
;
...
...
@@ -672,13 +673,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
<<
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.reduce_grid_desc_m_{ "
<<
arg
.
reduce_grid_desc_m_
.
GetLength
(
I0
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.reduce_grid_desc_m_{ "
<<
arg
.
reduce_grid_desc_m_
.
GetLength
(
I0
)
<<
"}"
<<
std
::
endl
;
}
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @
9b3c4ac4
...
...
@@ -858,7 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
{
if
(
ck
::
is_
navi3
_supported
())
if
(
ck
::
is_
gfx11
_supported
())
{
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
...
...
@@ -1435,7 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_
navi3
_supported())
if(ck::is_
gfx11
_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
9b3c4ac4
...
...
@@ -719,9 +719,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
arg
.
Print
();
#endif
}
if
(
!
ck
::
is_xdl_supported
())
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
View file @
9b3c4ac4
...
...
@@ -516,7 +516,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_container_{"
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
...
...
@@ -535,7 +536,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<<
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
#endif
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
View file @
9b3c4ac4
...
...
@@ -644,7 +644,7 @@ struct
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
DeviceOp
{}.
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"N "
<<
arg
.
Conv_N_
<<
", "
...
...
@@ -664,9 +664,7 @@ struct
<<
arg
.
input_left_pads_
[
1
]
<<
", "
<<
std
::
endl
;
std
::
cout
<<
"InLeftPads "
<<
arg
.
input_right_pads_
[
0
]
<<
", "
<<
arg
.
input_right_pads_
[
1
]
<<
", "
<<
std
::
endl
;
}
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
...
...
@@ -684,7 +682,6 @@ struct
std
::
cout
<<
"arg.c1_grid_desc_m_n_{ "
<<
arg
.
c1_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c1_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
View file @
9b3c4ac4
...
...
@@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
DeviceOp
{}.
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"N "
<<
arg
.
Conv_N_
<<
", "
...
...
@@ -634,9 +634,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
<<
arg
.
input_left_pads_
[
1
]
<<
", "
<<
std
::
endl
;
std
::
cout
<<
"InLeftPads "
<<
arg
.
input_right_pads_
[
0
]
<<
", "
<<
arg
.
input_right_pads_
[
1
]
<<
", "
<<
std
::
endl
;
}
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
...
...
@@ -651,7 +649,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
std
::
cout
<<
"arg.c0_grid_desc_m_n_{ "
<<
arg
.
c0_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c0_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
9b3c4ac4
...
...
@@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
DeviceOp
{}.
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"N "
<<
arg
.
Conv_N_
<<
", "
...
...
@@ -599,9 +599,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
<<
arg
.
input_left_pads_
[
1
]
<<
", "
<<
std
::
endl
;
std
::
cout
<<
"InLeftPads "
<<
arg
.
input_right_pads_
[
0
]
<<
", "
<<
arg
.
input_right_pads_
[
1
]
<<
", "
<<
std
::
endl
;
}
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
...
...
@@ -635,7 +633,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
.
GetLength
(
I5
)
<<
"}"
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
View file @
9b3c4ac4
...
...
@@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
...
...
@@ -444,7 +444,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
9b3c4ac4
...
...
@@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"num_batches_of_GEMM = "
<<
arg
.
num_subbatches_
<<
std
::
endl
;
std
::
cout
<<
"a_grid_desc_k0_m_k1{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
...
...
@@ -415,7 +415,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
std
::
cout
<<
"c_grid_desc_m_n{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
View file @
9b3c4ac4
...
...
@@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_container_{"
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
...
...
@@ -1305,7 +1305,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I5
)
<<
" ) "
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
...
...
@@ -1393,8 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
// check device
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_
navi2
_supported
()
||
ck
::
is_
navi3
_supported
()))
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_
gfx103
_supported
()
||
ck
::
is_
gfx11
_supported
()))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
View file @
9b3c4ac4
...
...
@@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1{"
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
...
...
@@ -1239,7 +1239,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
<<
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
View file @
9b3c4ac4
...
...
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
is_
navi3
_supported
())
if
(
ck
::
is_
gfx11
_supported
())
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
ck
::
half_t
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
View file @
9b3c4ac4
...
...
@@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"arg.a_grid_desc_k0_m0_m1_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
...
...
@@ -349,7 +349,6 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
))
...
...
@@ -536,8 +535,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
}
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_
navi2
_supported
()
||
ck
::
is_
navi3
_supported
())
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_
gfx103
_supported
()
||
ck
::
is_
gfx11
_supported
())
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp
View file @
9b3c4ac4
...
...
@@ -168,7 +168,7 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
karg
)
{
if
(
ck
::
is_
navi2
_supported
()
||
ck
::
is_
navi3
_supported
())
if
(
ck
::
is_
gfx103
_supported
()
||
ck
::
is_
gfx11
_supported
())
{
return
GridwiseGemm
::
CheckValidity
(
karg
);
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
View file @
9b3c4ac4
...
...
@@ -552,7 +552,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_xdl_supported
()
||
ck
::
is_
navi2
_supported
()
||
ck
::
is_
navi3
_supported
())
ck
::
is_
gfx103
_supported
()
||
ck
::
is_
gfx11
_supported
())
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
e_grid_desc_m_n_
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
View file @
9b3c4ac4
...
...
@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
is_
navi3
_supported
())
if
(
ck
::
is_
gfx11
_supported
())
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
View file @
9b3c4ac4
...
...
@@ -510,7 +510,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_{"
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
<<
", "
...
...
@@ -528,7 +528,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
std
::
cout
<<
"arg.reduce_grid_desc_m_{ "
<<
arg
.
reduce_grid_desc_m_
.
GetLength
(
I0
)
<<
"}"
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment