Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d3469a55
Commit
d3469a55
authored
Sep 01, 2023
by
Jing Zhang
Browse files
added kpad support into v2r3
parent
f5ec04f0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
16 deletions
+40
-16
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
+6
-6
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+34
-10
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
View file @
d3469a55
...
@@ -315,11 +315,6 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
...
@@ -315,11 +315,6 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
return
false
;
return
false
;
}
}
if
(
problem
.
K
%
K1
!=
0
)
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
problem
);
return
GridwiseGemm
::
CheckValidity
(
problem
);
}
}
...
@@ -416,7 +411,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
...
@@ -416,7 +411,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
">"
<<
">"
<<
" NumGemmKPrefetchStage: "
<<
" NumGemmKPrefetchStage: "
<<
NumGemmKPrefetchStage
<<
", "
<<
NumGemmKPrefetchStage
<<
", "
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
d3469a55
...
@@ -194,7 +194,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -194,7 +194,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
StrideC
{
StrideC_
},
StrideC
{
StrideC_
},
MPadded
{
CalculateMPadded
(
M_
)},
MPadded
{
CalculateMPadded
(
M_
)},
NPadded
{
CalculateNPadded
(
N_
)},
NPadded
{
CalculateNPadded
(
N_
)},
K0
{
CalculateK0
(
K
)}
K0
{
CalculateK0
(
K
_
)}
{
{
}
}
...
@@ -383,7 +383,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -383,7 +383,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
__host__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
const
index_t
KPad
=
math
::
integer_divide_ceil
(
K
,
K0PerBlock
*
K1
)
*
K0PerBlock
*
K1
;
const
index_t
num_loop
=
KPad
/
(
K0PerBlock
*
K1
);
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
}
...
@@ -840,11 +841,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
...
@@ -840,11 +841,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
}
}();
}();
const
auto
K0Pad
=
math
::
integer_divide_ceil
(
K0
,
K0PerBlock
)
*
K0PerBlock
;
const
auto
KPad
=
K0Pad
*
K1Value
;
const
auto
a_grid_desc_m_kpad
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
{
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
a_grid_desc_m_k
pad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Value
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
Pad
,
K1Value
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -852,8 +862,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
...
@@ -852,8 +862,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
else
else
{
{
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
a_grid_desc_m_k
pad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Value
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
Pad
,
K1Value
)),
make_pass_through_transform
(
M
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -874,11 +884,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
...
@@ -874,11 +884,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
}
}();
}();
const
auto
K0Pad
=
math
::
integer_divide_ceil
(
K0
,
K0PerBlock
)
*
K0PerBlock
;
const
auto
KPad
=
K0Pad
*
K1Value
;
const
auto
b_grid_desc_kpad_n
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_right_pad_transform
(
K
,
KPad
-
K
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
{
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
b_grid_desc_k
pad
_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Value
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
Pad
,
K1Value
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -886,8 +905,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
...
@@ -886,8 +905,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
else
else
{
{
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
b_grid_desc_k
pad
_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Value
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
Pad
,
K1Value
)),
make_pass_through_transform
(
N
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -947,6 +966,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
...
@@ -947,6 +966,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
}
}
}
if
(
problem
.
K
%
K1
!=
0
)
{
return
false
;
}
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment