gaoqiong / composable_kernel_ROCM

Commit e10cbe9e authored Sep 21, 2023 by Alan Turner

    Add constexpr IsSupported

Parent: d9676215
Showing 1 changed file with 111 additions and 58 deletions.
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp  (+111, -58)
@@ -611,6 +611,99 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return true;
     }

+    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
+    {
+        // check vector load/store
+        using Row = ck::tensor_layout::gemm::RowMajor;
+        using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+        // check vector load of A
+        if constexpr(is_same_v<ALayout, Row>)
+        {
+            if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<ALayout, Col>)
+        {
+            // FIXME: not rigorous
+            if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // check vector load of B
+        if constexpr(is_same_v<BLayout, Row>)
+        {
+            if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<BLayout, Col>)
+        {
+            // FIXME: not rigorous
+            if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // check vector load of B1
+        if constexpr(is_same_v<B1Layout, Row>)
+        {
+            if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<B1Layout, Col>)
+        {
+            // FIXME: not rigorous
+            if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // check vector load of C
+        if constexpr(is_same_v<CLayout, Row>)
+        {
+            if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<CLayout, Col>)
+        {
+            // FIXME: not rigorous
+            if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        return true;
+    }
+
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!ck::is_xdl_supported())
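The hunk above dispatches on the layout tag with `if constexpr`, so the extent that must be divisible by the vector width (the contiguous dimension of each tensor) is selected at compile time rather than through runtime ternaries. Below is a minimal, self-contained sketch of that pattern under assumed names (`is_vector_load_ok`, `RowMajor`, `ColumnMajor` are illustrative, not CK's types):

    // Sketch only: mirrors the if-constexpr layout dispatch above with
    // hypothetical names; not the CK implementation.
    #include <cassert>
    #include <type_traits>

    struct RowMajor {};
    struct ColumnMajor {};

    template <typename Layout, int ScalarPerVector>
    constexpr bool is_vector_load_ok(int m_raw, int k_raw)
    {
        if constexpr(std::is_same_v<Layout, RowMajor>)
        {
            // row-major: K is the contiguous dimension
            return k_raw % ScalarPerVector == 0;
        }
        else if constexpr(std::is_same_v<Layout, ColumnMajor>)
        {
            // column-major: M is contiguous (the "FIXME: not rigorous"
            // above flags that this ignores some corner cases)
            return m_raw % ScalarPerVector == 0;
        }
        else
        {
            return false; // unrecognized layout tag
        }
    }

    int main()
    {
        // constexpr lets the predicate gate compile-time-known sizes...
        static_assert(is_vector_load_ok<RowMajor, 8>(128, 64));
        // ...and still run on host or device with runtime sizes.
        assert((is_vector_load_ok<ColumnMajor, 8>(120, 64)));
        return 0;
    }

Because the predicate is `constexpr`, it can be evaluated wherever the raw extents are known, which is what allows the later hunks to fold it into the descriptor's cached `is_valid` flag.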
@@ -625,52 +718,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const auto KRaw      = arg.raw_lengths_m_n_k_o_[2];
         const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];

-        // Check scalar per vector requirement
-        const auto a_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
-        const auto b_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
-        const auto b1_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
-        const auto c_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
-
-        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
-             b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
-             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
-             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
-        {
-            return false;
-        }
-
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.b1_grid_desc_bk0_n_bk1_,
                                            arg.c_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
-    }
-
-    static constexpr bool IsSupported(index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw)
-    {
-        // Check scalar per vector requirement
-        const auto a_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
-        const auto b_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
-        const auto b1_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
-        const auto c_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
-
-        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
-             b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
-             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
-             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
-        {
-            return false;
-        }
-
-        return true;
+                                           arg.block_2_ctile_map_)
+               and IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
     }

     // polymorphic
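`IsSupportedArgument` now returns the conjunction of the descriptor-level `CheckValidity` and the shared `IsSupported` predicate instead of duplicating the divisibility logic inline. A hypothetical caller-side sketch of how such a gate is typically consumed (illustrative types and names, not a specific CK signature):

    // Sketch only: try each candidate op instance and run the first one
    // whose argument check passes, mirroring the IsSupportedArgument gate.
    #include <functional>
    #include <vector>

    struct OpInstance
    {
        std::function<bool()> is_supported; // wraps op.IsSupportedArgument(arg)
        std::function<void()> run;          // wraps invoker.Run(arg)
    };

    inline bool run_first_supported(const std::vector<OpInstance>& instances)
    {
        for(const auto& op : instances)
        {
            if(op.is_supported())
            {
                op.run();
                return true;
            }
        }
        return false; // no instance supports this problem shape
    }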
@@ -861,7 +914,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;

         // GridwiseGemm
-        using GridwiseGemmSpec = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
+        using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
             ADataType, // TODO: distinguish A/B datatype
             GemmAccDataType,
             CShuffleDataType,
@@ -928,8 +981,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
         CGridDesc_M_N c_grid_desc_m_n;
         C0MatrixMask c0_matrix_mask;
-        typename GridwiseGemmSpec::DefaultBlock2CTileMap block_2_ctile_map;
-        typename GridwiseGemmSpec::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            c_grid_descriptor_mblock_mperblock_nblock_nperblock;
+        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
+        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_descriptor_mblock_mperblock_nblock_nperblock;
         // element-wise op
         AElementwiseOperation a_element_op;
@@ -952,23 +1005,27 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
               b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
               b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
               c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
-              block_2_ctile_map{GridwiseGemmSpec::MakeDefaultBlock2CTileMap(
+              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(
                   c_grid_desc_m_n)},
               c_grid_descriptor_mblock_mperblock_nblock_nperblock{
-                  GridwiseGemmSpec::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                       c_grid_desc_m_n)},
-              has_main_k_block_loop{GridwiseGemmSpec::CalculateHasMainKBlockLoop(
+              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                   a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
               c0_matrix_mask{c.GetLength(I1)},
               a_element_op{a_element_op_},
               b_element_op{b_element_op_},
               b1_element_op{b1_element_op_},
               c_element_op{c_element_op_},
-              is_valid{GridwiseGemmSpec::CheckValidity(
+              is_valid{GridwiseGemm::CheckValidity(
                   a_grid_desc_ak0_m_ak1,
                   b_grid_desc_bk0_n_bk1,
                   b1_grid_desc_bk0_n_bk1,
                   c_grid_desc_m_n,
-                  block_2_ctile_map)}
+                  block_2_ctile_map)
+                  and IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
+                                  b_grid_desc_bk0_n_bk1.GetLength(I1),
+                                  a_grid_desc_ak0_m_ak1.GetLength(I0) *
+                                      a_grid_desc_ak0_m_ak1.GetLength(I2),
+                                  b1_grid_desc_bk0_n_bk1.GetLength(I1))}
         {
         }
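With this change the descriptor computes its validity once, at construction, as `CheckValidity(...) and IsSupported(M, N, K, Gemm1N)`, with the raw extents recovered from the grid descriptors (`GetLength(I1)` for M, N, and Gemm1N; `GetLength(I0) * GetLength(I2)` for K). A small sketch of that design choice with hypothetical types, assuming all-row-major layouts:

    // Sketch only: validity is folded into one flag at construction so
    // later code (including device code) checks a single cached bool.
    #include <cassert>

    struct DescriptorSketch
    {
        int m, n, k, gemm1_n;
        bool is_valid;

        constexpr DescriptorSketch(int m_, int n_, int k_, int gemm1_n_, int vec)
            : m{m_}, n{n_}, k{k_}, gemm1_n{gemm1_n_},
              // row-major checks: A needs K, B needs N, B1 and C need Gemm1N
              is_valid{k_ % vec == 0 && n_ % vec == 0 && gemm1_n_ % vec == 0}
        {
        }
    };

    int main()
    {
        constexpr DescriptorSketch desc{256, 128, 64, 128, /*vec=*/8};
        static_assert(desc.is_valid); // usable at compile time
        assert(desc.is_valid);        // what the kernel-side assert reduces to
        return 0;
    }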
@@ -1001,17 +1058,13 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                const ADataType* __restrict__ p_b1_grid,
                CDataType* __restrict__ p_c_grid)
 {
-    assert(desc.is_valid and
-           IsSupported(desc.a_grid_desc_ak0_m_ak1.GetLength(I1),
-                       desc.b_grid_desc_bk0_n_bk1.GetLength(I1),
-                       desc.a_grid_desc_ak0_m_ak1.GetLength(I0) *
-                           desc.a_grid_desc_ak0_m_ak1.GetLength(I2),
-                       desc.b1_grid_desc_bk0_n_bk1.GetLength(I1)));
-    __shared__ char p_shared_block[Desc::GridwiseGemmSpec::GetSharedMemoryNumberOfByte()];
+    assert(desc.is_valid);
+    __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];

     AccElementwiseOperation acc_element_op{scale};

     if(desc.has_main_k_block_loop)
     {
-        Desc::GridwiseGemmSpec::template Run<true>(p_a_grid,
+        Desc::GridwiseGemm::template Run<true>(p_a_grid,
                                                p_b_grid,
                                                p_b1_grid,
                                                p_c_grid,
@@ -1030,7 +1083,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
     }
     else
     {
-        Desc::GridwiseGemmSpec::template Run<false>(p_a_grid,
+        Desc::GridwiseGemm::template Run<false>(p_a_grid,
                                                 p_b_grid,
                                                 p_b1_grid,
                                                 p_c_grid,
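In both branches the K-loop flag is passed to `Run` as a non-type template parameter, so each instantiation compiles only its own loop structure and the runtime `if` merely picks between the two precompiled paths. A compact sketch of that dispatch pattern (hypothetical body, not the CK kernel):

    // Sketch only: bool-template dispatch as used by Run<true>/Run<false>.
    template <bool HasMainKBlockLoop>
    int run_sketch(int k_blocks)
    {
        int iterations = 0;
        if constexpr(HasMainKBlockLoop)
        {
            // main K loop: compiled only into the Run<true> instantiation
            for(int kb = 0; kb < k_blocks; ++kb)
            {
                ++iterations;
            }
        }
        else
        {
            // single-pass tail: the loop above is never instantiated here
            (void)k_blocks;
            iterations = 1;
        }
        return iterations;
    }

    int dispatch(bool has_main_k_block_loop, int k_blocks)
    {
        // the runtime flag only selects between two precompiled instantiations
        return has_main_k_block_loop ? run_sketch<true>(k_blocks)
                                     : run_sketch<false>(k_blocks);
    }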