Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
97648ccd
Commit
97648ccd
authored
May 31, 2023
by
Adam Osewski
Browse files
Use of PadTensorDescriptor for grid desc creation.
parent
0eff71a4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
102 deletions
+35
-102
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_direct_c_write_out.hpp
...mpl/device_grouped_gemm_xdl_splitk_direct_c_write_out.hpp
+2
-10
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out.hpp
...u/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out.hpp
+33
-92
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_direct_c_write_out.hpp
View file @
97648ccd
...
...
@@ -256,9 +256,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
const
index_t
stride_b
=
gemm_descs
[
i
].
stride_B_
;
const
index_t
stride_c
=
gemm_descs
[
i
].
stride_C_
;
const
index_t
m_padded
=
GridwiseGemm
::
CalculateMPadded
(
M
);
const
index_t
n_padded
=
GridwiseGemm
::
CalculateNPadded
(
N
);
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
K
,
K_BATCH
);
const
index_t
k0
=
GridwiseGemm
::
CalculateK0
(
K
,
K_BATCH
);
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
N
,
stride_c
);
...
...
@@ -285,9 +282,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
stride_a
,
stride_b
,
stride_c
,
m_padded
,
n_padded
,
k_padded
,
k0
,
K_BATCH
};
...
...
@@ -311,7 +305,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
auto
&
karg
=
gemm_kernel_args_
[
i
].
karg_
;
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
karg
.
K
,
K_BATCH
);
const
index_t
k0
=
GridwiseGemm
::
CalculateK0
(
karg
.
K
,
K_BATCH
);
const
auto
c_grid_desc_m_n
=
...
...
@@ -330,7 +323,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
auto
grouped_block_2_ctile_map
=
GroupedGemmBlock2ETileMap
(
local_b2c_tile_map
,
block_start
);
karg
.
KPadded
=
k_padded
;
karg
.
K0
=
k0
;
karg
.
k_batch
=
K_BATCH
;
gemm_kernel_args_
[
i
].
block_2_ctile_map_
=
grouped_block_2_ctile_map
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_direct_c_write_out.hpp
View file @
97648ccd
...
...
@@ -97,10 +97,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
KPerBlock
=
K1Value
*
K0PerBlock
;
static
constexpr
auto
gemm_padder
=
tensor_operation
::
device
::
GemmPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K1
*
K0PerBlock
};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
...
...
@@ -116,9 +112,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
index_t
StrideA
;
index_t
StrideB
;
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
index_t
KPadded
;
index_t
K0
;
index_t
k_batch
;
...
...
@@ -131,9 +124,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
,
index_t
MPadded_
,
index_t
NPadded_
,
index_t
KPadded_
,
index_t
K0_
,
index_t
k_batch_
)
:
p_a_grid
(
p_a_grid_
),
...
...
@@ -145,9 +135,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
StrideA
(
StrideA_
),
StrideB
(
StrideB_
),
StrideC
(
StrideC_
),
MPadded
(
MPadded_
),
NPadded
(
NPadded_
),
KPadded
(
KPadded_
),
K0
(
K0_
),
k_batch
(
k_batch_
)
{
...
...
@@ -162,9 +149,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SC:"
<<
StrideC
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"KP:"
<<
KPadded
<<
", "
<<
"K0:"
<<
K0
<<
", "
<<
"KB:"
<<
k_batch
<<
"}"
<<
std
::
endl
;
}
...
...
@@ -300,13 +284,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
}
}
__host__
__device__
static
auto
MakeAGridDescriptor_KBatch_K0_M_K1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
StrideA
,
index_t
KBatch
,
index_t
K0
,
index_t
KPad
)
__host__
__device__
static
auto
MakeAGridDescriptor_KBatch_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
,
index_t
KBatch
,
index_t
K0
)
{
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
...
...
@@ -319,43 +298,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
}
}();
const
auto
a_grid_desc_m_kpad
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
using
DoPads
=
Sequence
<
tensor_operation
::
device
::
GemmPadM
<
GemmSpec
>::
PadM
,
true
>
;
const
auto
a_grid_desc_mpad_kpad
=
tensor_operation
::
device
::
PadTensorDescriptor
(
a_grid_desc_m_k
,
make_tuple
(
MPerBlock
,
K0
*
K1
),
DoPads
{});
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return
transform_tensor_descriptor
(
a_grid_desc_m_kpad
,
a_grid_desc_mpad_kpad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_pass_through_transform
(
a_grid_desc_mpad_kpad
.
GetLength
(
I0
)
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_kpad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
}
__host__
__device__
static
auto
MakeBGridDescriptor_KBatch_K0_N_K1
(
index_t
K
,
index_t
NPad
,
index_t
N
,
index_t
StrideB
,
index_t
KBatch
,
index_t
K0
,
index_t
KPad
)
__host__
__device__
static
auto
MakeBGridDescriptor_KBatch_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
,
index_t
KBatch
,
index_t
K0
)
{
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
...
...
@@ -368,35 +324,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
}
}();
const
auto
b_grid_desc_kpad_n
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_right_pad_transform
(
K
,
KPad
-
K
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
using
DoPads
=
Sequence
<
true
,
tensor_operation
::
device
::
GemmPadN
<
GemmSpec
>::
PadN
>
;
const
auto
b_grid_desc_kpad_npad
=
tensor_operation
::
device
::
PadTensorDescriptor
(
b_grid_desc_k_n
,
make_tuple
(
K0
*
K1
,
NPerBlock
),
DoPads
{});
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return
transform_tensor_descriptor
(
b_grid_desc_kpad_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
{
return
transform_tensor_descriptor
(
b_grid_desc_kpad_n
,
b_grid_desc_kpad_npad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_pass_through_transform
(
N
)),
make_pass_through_transform
(
b_grid_desc_kpad_npad
.
GetLength
(
I1
)
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
}
__host__
__device__
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
{
...
...
@@ -411,7 +349,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
}
}();
return
gemm_padder
.
PadCDescriptor_M_N
(
c_grid_desc_m_n
);
using
DoPads
=
Sequence
<
tensor_operation
::
device
::
GemmPadM
<
GemmSpec
>::
PadM
,
tensor_operation
::
device
::
GemmPadN
<
GemmSpec
>::
PadN
>
;
return
tensor_operation
::
device
::
PadTensorDescriptor
(
c_grid_desc_m_n
,
make_tuple
(
MPerBlock
,
NPerBlock
),
DoPads
{});
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
...
...
@@ -615,10 +556,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
const
FloatAB
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatAB
*
p_b_grid
=
karg
.
p_b_grid
;
FloatC
*
p_c_grid
=
karg
.
p_c_grid
;
const
auto
a_b_k0_m_k1_grid_desc
=
MakeAGridDescriptor_KBatch_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
,
karg
.
k_batch
,
karg
.
K0
,
karg
.
KPadded
);
const
auto
b_b_k0_n_k1_grid_desc
=
MakeBGridDescriptor_KBatch_K0_N_K1
(
karg
.
K
,
karg
.
NPadded
,
karg
.
N
,
karg
.
StrideB
,
karg
.
k_batch
,
karg
.
K0
,
karg
.
KPadded
);
const
auto
a_b_k0_m_k1_grid_desc
=
MakeAGridDescriptor_KBatch_K0_M_K1
(
karg
.
M
,
karg
.
K
,
karg
.
StrideA
,
karg
.
k_batch
,
karg
.
K0
);
const
auto
b_b_k0_n_k1_grid_desc
=
MakeBGridDescriptor_KBatch_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
StrideB
,
karg
.
k_batch
,
karg
.
K0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
(
c_grid_desc_m_n
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment