Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9b3c4ac4
"vscode:/vscode.git/clone" did not exist on "7b01dbee0f878f0d6a54da3566401d8441a48233"
Commit
9b3c4ac4
authored
May 14, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
1d784873
7843a8a7
Changes
83
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1161 additions
and
240 deletions
+1161
-240
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp
...on/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp
..._operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
...device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
...evice_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
+16
-9
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+898
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
...ice/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
.../device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
...device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
...ion/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
+24
-19
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
+32
-28
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
...device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
+9
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
+23
-22
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+12
-10
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
...vice/impl/device_grouped_query_attention_forward_wmma.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
...device/impl/device_multi_query_attention_forward_wmma.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+69
-67
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp
...tion/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp
+64
-62
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
9b3c4ac4
...
@@ -443,7 +443,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -443,7 +443,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
ck
::
is_
navi3
_supported
())
if
(
ck
::
is_
gfx11
_supported
())
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
ck
::
half_t
>
||
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
ck
::
half_t
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
is_same_v
<
AccDataType
,
int32_t
>
))
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp
View file @
9b3c4ac4
...
@@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_{"
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_{"
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
<<
", "
...
@@ -529,7 +529,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -529,7 +529,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp
View file @
9b3c4ac4
...
@@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
...
@@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
...
@@ -312,7 +312,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
...
@@ -312,7 +312,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
View file @
9b3c4ac4
...
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
...
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
// check device
// check device
if
(
ck
::
is_
navi3
_supported
())
if
(
ck
::
is_
gfx11
_supported
())
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
View file @
9b3c4ac4
...
@@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
...
@@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
K0PerBlock
,
K0PerBlock
,
ConvBackwardWeightSpecialization
>
{};
ConvBackwardWeightSpecialization
>
{};
static
constexpr
index_t
MaxScalarPerVectorFP32
=
4
;
static
constexpr
index_t
WorkspaceInOutScalarPerVector
=
is_same_v
<
AccDataType
,
float
>
?
math
::
min
(
CBlockTransferScalarPerVector_NWaveNPerXdl
,
MaxScalarPerVectorFP32
)
:
CBlockTransferScalarPerVector_NWaveNPerXdl
;
// Bytes per 32 lds bank: 32 * 4 bytes
// Bytes per 32 lds bank: 32 * 4 bytes
static
constexpr
auto
BankLength
=
128
;
static
constexpr
auto
BankLength
=
128
;
static
constexpr
auto
ElePerBank
=
BankLength
/
sizeof
(
ADataType
);
static
constexpr
auto
ElePerBank
=
BankLength
/
sizeof
(
ADataType
);
...
@@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
...
@@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
ADataType
,
ADataType
,
BDataType
,
BDataType
,
AccDataType
,
AccDataType
,
E
DataType
,
Acc
DataType
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
...
@@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
...
@@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
BBlockLdsN1Padding
,
BBlockLdsN1Padding
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CBlockTransfer
ScalarPerVector
_NWaveNPerXdl
,
WorkspaceInOut
ScalarPerVector
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
true
,
true
,
true
,
true
,
...
@@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
...
@@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
static
constexpr
auto
MakeElementwiseInputSequence
()
static
constexpr
auto
MakeElementwiseInputSequence
()
{
{
return
generate_sequence_v2
(
return
generate_sequence_v2
(
[
&
](
auto
)
constexpr
{
return
Number
<
CBlockTransfer
ScalarPerVector
_NWaveNPerXdl
>
{};
},
[
&
](
auto
)
constexpr
{
return
Number
<
WorkspaceInOut
ScalarPerVector
>
{};
},
Number
<
NumDTensor
+
1
>
{});
Number
<
NumDTensor
+
1
>
{});
}
}
...
@@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
...
@@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
using
DsGridDesc_M_N
=
decltype
(
MakeDsGridDescriptor_M_N
<
NDimSpatial
>
({},
{}));
using
DsGridDesc_M_N
=
decltype
(
MakeDsGridDescriptor_M_N
<
NDimSpatial
>
({},
{}));
using
CDGridDesc_M_N
=
decltype
(
concat_tuple
(
Tuple
<
CGridDesc_M_N
>
{},
DsGridDesc_M_N
{}));
using
CDGridDesc_M_N
=
decltype
(
concat_tuple
(
Tuple
<
CGridDesc_M_N
>
{},
DsGridDesc_M_N
{}));
using
DsGridPointerTuple
=
decltype
(
GetDsGridPointerTuple
());
using
DsGridPointerTuple
=
decltype
(
GetDsGridPointerTuple
());
using
CDDataTypes
=
decltype
(
concat_tuple
(
Tuple
<
const
E
DataType
*>
{},
DsGridPointerTuple
{}));
using
CDDataTypes
=
decltype
(
concat_tuple
(
Tuple
<
const
Acc
DataType
*>
{},
DsGridPointerTuple
{}));
using
EGridDesc_M_N
=
CGridDesc_M_N
;
using
EGridDesc_M_N
=
CGridDesc_M_N
;
static
constexpr
index_t
ClusterLengthMPerBlock
=
static
constexpr
index_t
ClusterLengthMPerBlock
=
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
1
);
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
1
);
...
@@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
...
@@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
std
::
size_t
GetWorkspaceSizeBytes
()
const
std
::
size_t
GetWorkspaceSizeBytes
()
const
{
{
return
sizeof
(
E
DataType
)
*
ce_grid_desc_m_n_
.
GetElementSpaceSize
()
*
Conv_G_
;
return
sizeof
(
Acc
DataType
)
*
ce_grid_desc_m_n_
.
GetElementSpaceSize
()
*
Conv_G_
;
}
}
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
...
@@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
...
@@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
auto
launch_gemm_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
auto
launch_gemm_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
E
DataType
*
p_c_grid
=
type_convert
<
E
DataType
*>
(
arg
.
p_workspace_
);
Acc
DataType
*
p_c_grid
=
type_convert
<
Acc
DataType
*>
(
arg
.
p_workspace_
);
const
index_t
grid_size
=
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
ce_grid_desc_m_n_
)
*
arg
.
Conv_G_
;
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
ce_grid_desc_m_n_
)
*
arg
.
Conv_G_
;
...
@@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
...
@@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
E
DataType
,
Acc
DataType
,
OutElementwiseOperation
,
OutElementwiseOperation
,
InElementwiseOperation
,
InElementwiseOperation
,
element_wise
::
PassThrough
,
element_wise
::
PassThrough
,
...
@@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
...
@@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
};
};
auto
launch_elementwise_kernel
=
[
&
]()
{
auto
launch_elementwise_kernel
=
[
&
]()
{
const
E
DataType
*
p_c_grid
=
type_convert
<
const
E
DataType
*>
(
arg
.
p_workspace_
);
const
Acc
DataType
*
p_c_grid
=
type_convert
<
const
Acc
DataType
*>
(
arg
.
p_workspace_
);
const
index_t
grid_size
=
const
index_t
grid_size
=
arg
.
elementwise_block_2_ctile_map_
.
CalculateGridSize
(
arg
.
ce_grid_desc_m_n_
)
*
arg
.
elementwise_block_2_ctile_map_
.
CalculateGridSize
(
arg
.
ce_grid_desc_m_n_
)
*
arg
.
Conv_G_
;
arg
.
Conv_G_
;
...
@@ -907,7 +913,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
...
@@ -907,7 +913,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
}
}
// vector store C matrix into global memory
// vector store C matrix into global memory
if
(
!
(
arg
.
Conv_C_
%
CBlockTransferScalarPerVector_NWaveNPerXdl
==
0
))
if
(
!
(
arg
.
Conv_C_
%
CBlockTransferScalarPerVector_NWaveNPerXdl
==
0
&&
arg
.
Conv_C_
%
WorkspaceInOutScalarPerVector
==
0
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
0 → 100644
View file @
9b3c4ac4
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
View file @
9b3c4ac4
...
@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
...
@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
// check device
// check device
if
(
ck
::
is_
navi3
_supported
())
if
(
ck
::
is_
gfx11
_supported
())
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
View file @
9b3c4ac4
...
@@ -666,7 +666,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
...
@@ -666,7 +666,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device
// check device
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_xdl_supported
()
||
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_xdl_supported
()
||
ck
::
is_
navi2
_supported
()
||
ck
::
is_
navi3
_supported
()))
ck
::
is_
gfx103
_supported
()
||
ck
::
is_
gfx11
_supported
()))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
View file @
9b3c4ac4
...
@@ -601,8 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
...
@@ -601,8 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
namespace
ctc
=
tensor_layout
::
convolution
;
namespace
ctc
=
tensor_layout
::
convolution
;
// check device
// check device
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_
navi2
_supported
()
||
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_
gfx103
_supported
()
||
ck
::
is_
navi3
_supported
()))
ck
::
is_
gfx11
_supported
()))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
9b3c4ac4
...
@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace
ctc
=
tensor_layout
::
convolution
;
namespace
ctc
=
tensor_layout
::
convolution
;
// check device
// check device
if
(
ck
::
is_
navi3
_supported
())
if
(
ck
::
is_
gfx11
_supported
())
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
View file @
9b3c4ac4
...
@@ -553,24 +553,29 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
...
@@ -553,24 +553,29 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_k0_m_k1_{"
{
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
", "
<<
std
::
endl
;
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
std
::
cout
<<
", arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
"}"
<<
std
::
endl
;
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
std
::
cout
<<
", arg.b_grid_desc_k0_n_k1_{"
<<
std
::
endl
;
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
", arg.e_grid_desc_m_n_{ "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
std
::
endl
;
<<
"}"
<<
std
::
endl
;
#endif
std
::
cout
<<
", arg.e_grid_desc_m_n_{ "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_k0_n_k1_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_k0_n_k1_
,
...
@@ -668,7 +673,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
...
@@ -668,7 +673,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
}
}
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_xdl_supported
()
||
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_xdl_supported
()
||
ck
::
is_
navi2
_supported
()
||
ck
::
is_
navi3
_supported
())
ck
::
is_
gfx103
_supported
()
||
ck
::
is_
gfx11
_supported
())
{
{
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
View file @
9b3c4ac4
...
@@ -467,18 +467,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -467,18 +467,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
gemm_kernel_args_
[
i
].
block_start_
=
block_start
;
gemm_kernel_args_
[
i
].
block_start_
=
block_start
;
gemm_kernel_args_
[
i
].
block_end_
=
block_end
;
gemm_kernel_args_
[
i
].
block_end_
=
block_end
;
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
index_t
tiles
=
(
block_end
-
block_start
)
/
K_BATCH
;
{
std
::
cout
<<
"block_start: "
<<
block_start
<<
"
\n
"
index_t
tiles
=
(
block_end
-
block_start
)
/
K_BATCH
;
<<
"block_end: "
<<
block_end
<<
"
\n
"
std
::
cout
<<
"block_start: "
<<
block_start
<<
"
\n
"
<<
"tiles: "
<<
tiles
<<
std
::
endl
<<
"block_end: "
<<
block_end
<<
"
\n
"
<<
std
::
endl
;
<<
"tiles: "
<<
tiles
<<
std
::
endl
<<
std
::
endl
;
std
::
cout
<<
"KPadded: "
<<
karg
.
KPadded
<<
std
::
endl
<<
"K0Padded: "
<<
karg
.
K0Padded
<<
std
::
endl
std
::
cout
<<
"KPadded: "
<<
karg
.
KPadded
<<
std
::
endl
<<
"KBatch: "
<<
karg
.
k_batch
<<
std
::
endl
<<
"K0Padded: "
<<
karg
.
K0Padded
<<
std
::
endl
<<
"grid_size_: "
<<
karg
.
KPadded
<<
std
::
endl
;
<<
"KBatch: "
<<
karg
.
k_batch
<<
std
::
endl
#endif
<<
"grid_size_: "
<<
karg
.
KPadded
<<
std
::
endl
;
}
}
}
}
}
...
@@ -493,12 +494,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -493,12 +494,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
arg
.
karg_
.
p_c_grid
=
p_workspace
+
offset
;
arg
.
karg_
.
p_c_grid
=
p_workspace
+
offset
;
index_t
tiles
=
(
arg
.
block_end_
-
arg
.
block_start_
)
/
arg
.
karg_
.
k_batch
;
index_t
tiles
=
(
arg
.
block_end_
-
arg
.
block_start_
)
/
arg
.
karg_
.
k_batch
;
offset
+=
tiles
*
MPerBlock
*
NPerBlock
;
offset
+=
tiles
*
MPerBlock
*
NPerBlock
;
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"block_start: "
<<
arg
.
block_start_
<<
"
\n
"
{
<<
"block_end: "
<<
arg
.
block_end_
<<
"
\n
"
std
::
cout
<<
"block_start: "
<<
arg
.
block_start_
<<
"
\n
"
<<
"tiles: "
<<
tiles
<<
"
\n
"
<<
"block_end: "
<<
arg
.
block_end_
<<
"
\n
"
<<
"offset: "
<<
offset
<<
std
::
endl
;
<<
"tiles: "
<<
tiles
<<
"
\n
"
#endif
<<
"offset: "
<<
offset
<<
std
::
endl
;
}
}
}
}
}
...
@@ -816,11 +818,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -816,11 +818,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
if
((
ck
::
type_convert
<
ck
::
index_t
>
(
arg
.
gemm_kernel_args_
.
size
())
+
if
((
ck
::
type_convert
<
ck
::
index_t
>
(
arg
.
gemm_kernel_args_
.
size
())
+
arg
.
skipped_group_count_
)
!=
arg
.
group_count_
)
arg
.
skipped_group_count_
)
!=
arg
.
group_count_
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"The group count is not equal to sum of skipped groups "
{
"and kernel args size!"
std
::
cout
<<
"The group count is not equal to sum of skipped groups "
<<
std
::
endl
;
"and kernel args size!"
#endif // DEBUG_LOG
<<
std
::
endl
;
}
return
false
;
return
false
;
}
}
...
@@ -832,11 +835,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -832,11 +835,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
bool
group_arg_valid
=
GridwiseGemm
::
CheckValidity
(
gemm_arg
);
bool
group_arg_valid
=
GridwiseGemm
::
CheckValidity
(
gemm_arg
);
if
(
not
group_arg_valid
)
if
(
not
group_arg_valid
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"["
<<
__func__
<<
"] group id: "
<<
i
{
<<
" has invalid GridwiseGemm settings!"
<<
std
::
endl
;
std
::
cout
<<
"["
<<
__func__
<<
"] group id: "
<<
i
gemm_arg
.
Print
();
<<
" has invalid GridwiseGemm settings!"
<<
std
::
endl
;
#endif // DEBUG_LOG
gemm_arg
.
Print
();
}
}
}
supported
=
supported
&&
group_arg_valid
;
supported
=
supported
&&
group_arg_valid
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
View file @
9b3c4ac4
...
@@ -375,7 +375,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
...
@@ -375,7 +375,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
std
::
vector
<
const
void
*>&
/* p_Bs */
,
std
::
vector
<
const
void
*>&
/* p_Bs */
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
/* p_Ds */
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
/* p_Ds */
,
std
::
vector
<
void
*>&
/* p_Es */
,
std
::
vector
<
void
*>&
/* p_Es */
,
std
::
vector
<
GemmDesc
>&
gemm_descs
,
const
std
::
vector
<
GemmDesc
>&
gemm_descs
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
,
CDEElementwiseOperation
cde_element_op
,
...
@@ -620,11 +620,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
...
@@ -620,11 +620,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
GridwiseGemm
::
template
CheckTensorTransfersValidity
<
ALayout
,
BLayout
,
ELayout
>(
GridwiseGemm
::
template
CheckTensorTransfersValidity
<
ALayout
,
BLayout
,
ELayout
>(
M
,
N
,
K
)))
M
,
N
,
K
)))
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"The provided GEMM problem size (M,N,K) ["
<<
M
<<
","
<<
N
<<
","
<<
K
{
<<
"] are not supported by current template parameters!"
std
::
cout
<<
"The provided GEMM problem size (M,N,K) ["
<<
M
<<
","
<<
N
<<
","
<<
" In "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
<<
K
<<
"] are not supported by current template parameters!"
#endif
<<
" In "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
}
supported
=
false
;
supported
=
false
;
}
}
}
}
...
@@ -641,7 +643,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
...
@@ -641,7 +643,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>
gemm_descs
,
std
::
vector
<
GemmDesc
>
&
gemm_descs
,
AElementwiseOperation
a_elementwise_op
,
AElementwiseOperation
a_elementwise_op
,
BElementwiseOperation
b_elementwise_op
,
BElementwiseOperation
b_elementwise_op
,
CDEElementwiseOperation
cde_elementwise_op
)
CDEElementwiseOperation
cde_elementwise_op
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
View file @
9b3c4ac4
...
@@ -514,28 +514,29 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
...
@@ -514,28 +514,29 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_ak0_m_ak1_{"
{
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_ak0_m_ak1_{"
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
)
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
)
<<
", "
<<
"}"
;
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
)
<<
"}"
;
std
::
cout
<<
", arg.b_grid_desc_bk0_n_bk1_{"
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
std
::
cout
<<
", arg.b_grid_desc_bk0_n_bk1_{"
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
)
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
<<
", "
<<
"}"
;
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
<<
"}"
;
std
::
cout
<<
", arg.e_grid_desc_m_n_{ "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
", arg.e_grid_desc_m_n_{ "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
std
::
endl
;
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
#endif
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_m_k_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_m_k_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_n_k_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_n_k_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
9b3c4ac4
...
@@ -529,11 +529,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -529,11 +529,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if
((
ck
::
type_convert
<
ck
::
index_t
>
(
arg
.
gemm_kernel_args_
.
size
())
+
if
((
ck
::
type_convert
<
ck
::
index_t
>
(
arg
.
gemm_kernel_args_
.
size
())
+
arg
.
skipped_group_count_
)
!=
arg
.
group_count_
)
arg
.
skipped_group_count_
)
!=
arg
.
group_count_
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"The group count is not equal to sum of skipped groups "
{
"and kernel args size!"
std
::
cout
<<
"The group count is not equal to sum of skipped groups "
<<
std
::
endl
;
"and kernel args size!"
#endif // DEBUG_LOG
<<
std
::
endl
;
}
return
false
;
return
false
;
}
}
...
@@ -544,11 +545,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -544,11 +545,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool
group_arg_valid
=
GridwiseGemm
::
CheckValidity
(
a
);
bool
group_arg_valid
=
GridwiseGemm
::
CheckValidity
(
a
);
if
(
not
group_arg_valid
)
if
(
not
group_arg_valid
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"["
<<
__func__
<<
"] group id: "
<<
i
{
<<
" has invalid GridwiseGemm settings!"
<<
std
::
endl
;
std
::
cout
<<
"["
<<
__func__
<<
"] group id: "
<<
i
a
.
Print
();
<<
" has invalid GridwiseGemm settings!"
<<
std
::
endl
;
#endif // DEBUG_LOG
a
.
Print
();
}
}
}
supported
=
supported
&&
group_arg_valid
;
supported
=
supported
&&
group_arg_valid
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
View file @
9b3c4ac4
...
@@ -596,7 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
...
@@ -596,7 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
{
{
if
(
ck
::
is_
navi3
_supported
())
if
(
ck
::
is_
gfx11
_supported
())
{
{
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
{
...
@@ -958,7 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
...
@@ -958,7 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
#if 0
#if 0
static bool IsSupportedArgument(const Argument& arg)
static bool IsSupportedArgument(const Argument& arg)
{
{
if(ck::is_
navi3
_supported())
if(ck::is_
gfx11
_supported())
{
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
View file @
9b3c4ac4
...
@@ -594,7 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
...
@@ -594,7 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
{
{
if
(
ck
::
is_
navi3
_supported
())
if
(
ck
::
is_
gfx11
_supported
())
{
{
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
{
...
@@ -950,7 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
...
@@ -950,7 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
#if 0
#if 0
static bool IsSupportedArgument(const Argument& arg)
static bool IsSupportedArgument(const Argument& arg)
{
{
if(ck::is_
navi3
_supported())
if(ck::is_
gfx11
_supported())
{
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
{
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
9b3c4ac4
...
@@ -260,7 +260,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
...
@@ -260,7 +260,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
};
};
// Grouped Rows of column-vectors WGP mapping
// Grouped Rows of column-vectors WGP mapping
// Optimized for
MI300
-like multipe-die chip
// Optimized for
gfx94x
-like multipe-die chip
template
<
index_t
GroupNum
,
index_t
MPerBlock
,
index_t
NPerBlock
>
template
<
index_t
GroupNum
,
index_t
MPerBlock
,
index_t
NPerBlock
>
struct
BlockToCTileMap_Grouped_M00_N0_M01Adapt
struct
BlockToCTileMap_Grouped_M00_N0_M01Adapt
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
9b3c4ac4
...
@@ -935,12 +935,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -935,12 +935,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg M value is not a multiple of MPerBlock! M: "
<<
karg
.
M
<<
" "
{
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
std
::
cout
<<
"Arg M value is not a multiple of MPerBlock! M: "
<<
karg
.
M
<<
" "
<<
std
::
endl
;
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -952,12 +952,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -952,12 +952,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg N value is not a multiple of NPerBlock! N: "
<<
karg
.
N
<<
" "
{
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
std
::
cout
<<
"Arg N value is not a multiple of NPerBlock! N: "
<<
karg
.
N
<<
" "
<<
std
::
endl
;
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -971,12 +971,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -971,12 +971,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto
K_t
=
karg
.
KBatch
*
KPerBlock
;
auto
K_t
=
karg
.
KBatch
*
KPerBlock
;
if
(
!
(
karg
.
K
%
K_t
==
0
))
if
(
!
(
karg
.
K
%
K_t
==
0
))
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
{
<<
karg
.
K
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<<
", in function:
"
<<
__
func
__
<<
std
::
endl
;
<<
karg
.
K
<<
"
"
<<
__
FILE
__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -995,13 +995,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -995,13 +995,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg K ("
<<
karg
.
K
{
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
ABlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":
"
<<
") value is not a multiple of
ABlockTransferSrcScalarPerVector
(
"
<<
__LINE__
<<
", in function:
"
<<
__
func
__
<<
std
::
endl
;
<<
ABlockTransferSrcScalarPerVector
<<
" )!
"
<<
__
FILE
__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1009,13 +1009,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1009,13 +1009,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg M ("
<<
karg
.
M
{
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
ABlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":
"
<<
") value is not a multiple of
ABlockTransferSrcScalarPerVector
(
"
<<
__LINE__
<<
", in function:
"
<<
__
func
__
<<
std
::
endl
;
<<
ABlockTransferSrcScalarPerVector
<<
" )!
"
<<
__
FILE
__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1024,13 +1024,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1024,13 +1024,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg N ("
<<
karg
.
N
{
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
BBlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":
"
<<
") value is not a multiple of
BBlockTransferSrcScalarPerVector
(
"
<<
__LINE__
<<
", in function:
"
<<
__
func
__
<<
std
::
endl
;
<<
BBlockTransferSrcScalarPerVector
<<
" )!
"
<<
__
FILE
__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1038,13 +1038,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1038,13 +1038,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg K ("
<<
karg
.
K
{
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
BBlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":
"
<<
") value is not a multiple of
BBlockTransferSrcScalarPerVector
(
"
<<
__LINE__
<<
", in function:
"
<<
__
func
__
<<
std
::
endl
;
<<
BBlockTransferSrcScalarPerVector
<<
" )!
"
<<
__
FILE
__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1053,14 +1053,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1053,14 +1053,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
N
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
if
(
karg
.
N
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg N ("
<<
karg
.
N
{
<<
") value is not a multiple of "
std
::
cout
<<
"Arg N ("
<<
karg
.
N
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<<
") value is not a multiple of "
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
" )! "
<<
__FILE__
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
#endif // DEBUG_LOG
<<
std
::
endl
;
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1068,25 +1069,26 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1068,25 +1069,26 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
M
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
if
(
karg
.
M
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg M ("
<<
karg
.
M
{
<<
") value is not a multiple of "
std
::
cout
<<
"Arg M ("
<<
karg
.
M
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<<
") value is not a multiple of "
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
" )! "
<<
__FILE__
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
#endif // DEBUG_LOG
<<
std
::
endl
;
}
return
false
;
return
false
;
}
}
}
}
if
constexpr
(
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
if
constexpr
(
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
" KBatch: "
<<
karg
.
KBatch
<<
" > 1 is not support yet"
<<
__FILE__
<<
":"
{
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
std
::
cout
<<
" KBatch: "
<<
karg
.
KBatch
<<
" > 1 is not support yet"
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
if
(
karg
.
KBatch
>
1
)
if
(
karg
.
KBatch
>
1
)
{
{
return
false
;
return
false
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp
View file @
9b3c4ac4
...
@@ -1113,12 +1113,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1113,12 +1113,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg M value is not a multiple of MPerBlock! M: "
<<
karg
.
M
<<
" "
{
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
std
::
cout
<<
"Arg M value is not a multiple of MPerBlock! M: "
<<
karg
.
M
<<
" "
<<
std
::
endl
;
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1130,12 +1130,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1130,12 +1130,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg N value is not a multiple of NPerBlock! N: "
<<
karg
.
N
<<
" "
{
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
std
::
cout
<<
"Arg N value is not a multiple of NPerBlock! N: "
<<
karg
.
N
<<
" "
<<
std
::
endl
;
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1149,12 +1149,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1149,12 +1149,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto
K_t
=
karg
.
KBatch
*
KPerBlock
;
auto
K_t
=
karg
.
KBatch
*
KPerBlock
;
if
(
!
(
karg
.
K
%
K_t
==
0
))
if
(
!
(
karg
.
K
%
K_t
==
0
))
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
{
<<
karg
.
K
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<<
", in function:
"
<<
__
func
__
<<
std
::
endl
;
<<
karg
.
K
<<
"
"
<<
__
FILE
__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1173,13 +1173,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1173,13 +1173,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg K ("
<<
karg
.
K
{
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
ABlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":
"
<<
") value is not a multiple of
ABlockTransferSrcScalarPerVector
(
"
<<
__LINE__
<<
", in function:
"
<<
__
func
__
<<
std
::
endl
;
<<
ABlockTransferSrcScalarPerVector
<<
" )!
"
<<
__
FILE
__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1187,13 +1187,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1187,13 +1187,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg M ("
<<
karg
.
M
{
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
ABlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":
"
<<
") value is not a multiple of
ABlockTransferSrcScalarPerVector
(
"
<<
__LINE__
<<
", in function:
"
<<
__
func
__
<<
std
::
endl
;
<<
ABlockTransferSrcScalarPerVector
<<
" )!
"
<<
__
FILE
__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1202,13 +1202,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1202,13 +1202,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg N ("
<<
karg
.
N
{
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
BBlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":
"
<<
") value is not a multiple of
BBlockTransferSrcScalarPerVector
(
"
<<
__LINE__
<<
", in function:
"
<<
__
func
__
<<
std
::
endl
;
<<
BBlockTransferSrcScalarPerVector
<<
" )!
"
<<
__
FILE
__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1216,13 +1216,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1216,13 +1216,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg K ("
<<
karg
.
K
{
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
BBlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":
"
<<
") value is not a multiple of
BBlockTransferSrcScalarPerVector
(
"
<<
__LINE__
<<
", in function:
"
<<
__
func
__
<<
std
::
endl
;
<<
BBlockTransferSrcScalarPerVector
<<
" )!
"
<<
__
FILE
__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1231,14 +1231,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1231,14 +1231,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
N
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
if
(
karg
.
N
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg N ("
<<
karg
.
N
{
<<
") value is not a multiple of "
std
::
cout
<<
"Arg N ("
<<
karg
.
N
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<<
") value is not a multiple of "
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
" )! "
<<
__FILE__
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
#endif // DEBUG_LOG
<<
std
::
endl
;
}
return
false
;
return
false
;
}
}
}
}
...
@@ -1246,14 +1247,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1246,14 +1247,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
M
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
if
(
karg
.
M
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
{
#if DEBUG_LOG
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
std
::
cout
<<
"Arg M ("
<<
karg
.
M
{
<<
") value is not a multiple of "
std
::
cout
<<
"Arg M ("
<<
karg
.
M
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<<
") value is not a multiple of "
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
" )! "
<<
__FILE__
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
#endif // DEBUG_LOG
<<
std
::
endl
;
}
return
false
;
return
false
;
}
}
}
}
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment