Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2bb9444f
Commit
2bb9444f
authored
Jun 23, 2023
by
Paul
Browse files
Allow storing stateful lambdas
parent
7838a9a8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
99 deletions
+57
-99
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
...n/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
+57
-99
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
2bb9444f
...
@@ -580,87 +580,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -580,87 +580,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return
false
;
return
false
;
}
}
// check vector load/store
return
IsSupported
(
arg
.
MRaw_
,
arg
.
NRaw_
,
arg
.
KRaw_
)
and
{
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// check vector load of A
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
ABlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
MRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector laod of B
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
{
if
(
arg
.
KRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
BLayout
,
Row
>
&&
BBlockTransferSrcVectorDim
==
1
)
{
// FIXME: not rigorous
if
(
arg
.
NRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector load of Ds
// only support RowMajor for now
bool
all_valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
constexpr
(
!
is_same_v
<
DLayout
,
Row
>
)
{
all_valid
=
false
;
}
});
if
(
!
all_valid
)
{
return
false
;
}
// check vector store of E
// only support RowMajor for now
if
constexpr
(
is_same_v
<
ELayout
,
Row
>
)
{
if
(
arg
.
NRaw_
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
...
@@ -812,10 +733,27 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -812,10 +733,27 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
;
Block2ETileMap
block_2_etile_map
;
Block2ETileMap
block_2_etile_map
;
// for checking vector load/store
index_t
MRaw
;
index_t
NRaw
;
index_t
KRaw
;
// element-wise op
AElementwiseOperation
a_element_op
;
BElementwiseOperation
b_element_op
;
CDEElementwiseOperation
cde_element_op
;
bool
has_main_k_block_loop
=
true
;
bool
has_main_k_block_loop
=
true
;
bool
is_valid
=
false
;
bool
is_valid
=
false
;
constexpr
Descriptor
(
ADesc
a
,
BDesc
b
,
DsDesc
ds
,
EDesc
e
)
constexpr
Descriptor
(
ADesc
a
,
BDesc
b
,
DsDesc
ds
,
EDesc
e
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
CDEElementwiseOperation
cde_element_op_
)
:
a_grid_desc_ak0_m_ak1
{
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
:
a_grid_desc_ak0_m_ak1
{
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
a
))},
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
a
))},
b_grid_desc_bk0_n_bk1
{
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_bk0_n_bk1
{
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
...
@@ -834,24 +772,44 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -834,24 +772,44 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
e
))},
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
e
))},
has_main_k_block_loop
{
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
has_main_k_block_loop
{
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))},
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))},
MRaw
{
e
.
GetLength
(
I0
)},
NRaw
{
e
.
GetLength
(
I1
)},
KRaw
{
a
.
GetLength
(
I1
)},
a_element_op
{
a_element_op_
},
b_element_op
{
b_element_op_
},
cde_element_op
{
cde_element_op_
},
is_valid
{
GridwiseGemm
::
CheckValidity
(
is_valid
{
GridwiseGemm
::
CheckValidity
(
(
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
a
)),
(
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
a
)),
DeviceOp
::
matrix_padder
.
PadBDescriptor_N_K
(
b
),
DeviceOp
::
matrix_padder
.
PadBDescriptor_N_K
(
b
),
transform_tuples
(
transform_tuples
(
[
&
](
auto
d
)
constexpr
{
[
&
](
auto
d
)
constexpr
{
return
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
d
);
return
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
d
);
},
},
ds
),
ds
),
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
e
),
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
e
),
block_2_etile_map
)
and
IsSupported
(
e
.
GetLength
(
I0
),
e
.
GetLength
(
I1
),
a
.
GetLength
(
I1
))}
block_2_etile_map
)
and
IsSupported
(
MRaw
,
NRaw
,
KRaw
)}
{
}
constexpr
bool
IsValid
()
const
{
{
return
is_valid
;
}
}
};
};
template
<
class
ADesc
,
class
BDesc
,
class
DsDesc
,
class
EDesc
>
template
<
class
ADesc
,
class
BDesc
,
class
DsDesc
,
class
EDesc
>
static
constexpr
auto
make_descriptor
(
ADesc
a
,
BDesc
b
,
DsDesc
ds
,
EDesc
e
)
static
constexpr
auto
make_descriptor
(
ADesc
a
,
BDesc
b
,
DsDesc
ds
,
EDesc
e
,
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{},
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{},
CDEElementwiseOperation
cde_element_op
=
CDEElementwiseOperation
{})
{
{
return
Descriptor
<
ADesc
,
BDesc
,
DsDesc
,
EDesc
>
(
a
,
b
,
ds
,
e
);
return
Descriptor
<
ADesc
,
BDesc
,
DsDesc
,
EDesc
>
(
a
,
b
,
ds
,
e
,
a_element_op
,
b_element_op
,
cde_element_op
);
}
}
template
<
class
Desc
,
class
DsPointer
>
template
<
class
Desc
,
class
DsPointer
>
...
@@ -870,9 +828,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -870,9 +828,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
p_ds_grid
,
p_ds_grid
,
p_e_grid
,
p_e_grid
,
p_shared_block
,
p_shared_block
,
AElementwiseOperation
{}
,
desc
.
a_element_op
,
BElementwiseOperation
{}
,
desc
.
b_element_op
,
CDEElementwiseOperation
{}
,
desc
.
cde_element_op
,
desc
.
a_grid_desc_ak0_m_ak1
,
desc
.
a_grid_desc_ak0_m_ak1
,
desc
.
b_grid_desc_bk0_n_bk1
,
desc
.
b_grid_desc_bk0_n_bk1
,
desc
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
desc
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -886,9 +844,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -886,9 +844,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
p_ds_grid
,
p_ds_grid
,
p_e_grid
,
p_e_grid
,
p_shared_block
,
p_shared_block
,
AElementwiseOperation
{}
,
desc
.
a_element_op
,
BElementwiseOperation
{}
,
desc
.
b_element_op
,
CDEElementwiseOperation
{}
,
desc
.
cde_element_op
,
desc
.
a_grid_desc_ak0_m_ak1
,
desc
.
a_grid_desc_ak0_m_ak1
,
desc
.
b_grid_desc_bk0_n_bk1
,
desc
.
b_grid_desc_bk0_n_bk1
,
desc
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
desc
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment