Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
148d9e57
Commit
148d9e57
authored
May 04, 2023
by
Po-Yen, Chen
Browse files
Move kernel arg type definition into GridwiseGemm
parent
affdca9d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
86 additions
and
104 deletions
+86
-104
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+4
-100
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+82
-4
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
148d9e57
...
...
@@ -130,106 +130,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
LoopSched
,
PipelineVer
>
;
// Descriptor types for the A, B and C grids. Each alias is the return type of the
// matching GridwiseGemm::Make*Descriptor function; the literal 1s are placeholder
// arguments that sit inside decltype and are therefore never evaluated.
using AGridDesc_AK0_M_AK1 =
    decltype(GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1, 1));
using BGridDesc_BK0_N_BK1 =
    decltype(GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(GridwiseGemm::MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// Argument
/// Argument bundle handed to the GEMM kernel launch.
///
/// Holds the three grid pointers and the problem sizes/strides exactly as given,
/// plus values derived from them through GridwiseGemm: the padded M/N/K sizes,
/// AK0/BK0, and the A/B/C grid descriptors.
///
/// NOTE(review): the destructor is declared `override`, so this type is polymorphic,
/// and an instance is forwarded to launch_and_time_kernel as the kernel argument —
/// confirm that passing a polymorphic object to the kernel is intended.
struct Argument : public BaseArgument
{
    /// Stores the inputs and derives every remaining member from them.
    ///
    /// Members are initialized in declaration order; the descriptor initializers
    /// call the GridwiseGemm::Calculate* helpers again instead of reading the
    /// already-initialized padded members.
    __host__ Argument(const ADataType* p_a_grid_,
                      const BDataType* p_b_grid_,
                      CDataType* p_c_grid_,
                      index_t M_,
                      index_t N_,
                      index_t K_,
                      index_t StrideA_,
                      index_t StrideB_,
                      index_t StrideC_)
        : p_a_grid{p_a_grid_},
          p_b_grid{p_b_grid_},
          p_c_grid{p_c_grid_},
          M{M_},
          N{N_},
          K{K_},
          StrideA{StrideA_},
          StrideB{StrideB_},
          StrideC{StrideC_},
          MPadded{GridwiseGemm::CalculateMPadded(M_)},
          NPadded{GridwiseGemm::CalculateNPadded(N_)},
          KPadded{GridwiseGemm::CalculateKPadded(K_)},
          AK0{GridwiseGemm::CalculateAK0(K_)},
          BK0{GridwiseGemm::CalculateBK0(K_)},
          a_grid_desc_ak0_m_ak1{
              GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(M_,
                                                          GridwiseGemm::CalculateMPadded(M_),
                                                          K_,
                                                          GridwiseGemm::CalculateKPadded(K_),
                                                          StrideA_,
                                                          GridwiseGemm::CalculateAK0(K_))},
          b_grid_desc_bk0_n_bk1{
              GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(K_,
                                                          GridwiseGemm::CalculateKPadded(K_),
                                                          N_,
                                                          GridwiseGemm::CalculateNPadded(N_),
                                                          StrideB_,
                                                          GridwiseGemm::CalculateBK0(K_))},
          c_grid_desc_m_n{
              GridwiseGemm::MakeCGridDescriptor_M_N(M_,
                                                    GridwiseGemm::CalculateMPadded(M_),
                                                    N_,
                                                    GridwiseGemm::CalculateNPadded(N_),
                                                    StrideC_)}
    {
    }

    /// Writes the scalar members to std::cout as one "arg {...}" line.
    ///
    /// NOTE(review): this is marked __host__ __device__ but streams to std::cout;
    /// confirm it is only ever called from host code.
    __host__ __device__ void Print() const
    {
        // One statement per field; the insertion sequence is the same as a single chain.
        std::cout << "arg {";
        std::cout << "M:" << M << ", ";
        std::cout << "N:" << N << ", ";
        std::cout << "K:" << K << ", ";
        std::cout << "SA:" << StrideA << ", ";
        std::cout << "SB:" << StrideB << ", ";
        std::cout << "SC:" << StrideC << ", ";
        std::cout << "MP:" << MPadded << ", ";
        std::cout << "NP:" << NPadded << ", ";
        std::cout << "KP:" << KPadded << ", ";
        std::cout << "AK0:" << AK0 << ", ";
        std::cout << "BK0:" << BK0;
        std::cout << "}" << std::endl;
    }

    /// Member-wise copy, callable from host and device code.
    __host__ __device__ Argument(const Argument&) = default;

    /// Empty destructor overriding the base class destructor.
    __host__ __device__ ~Argument() override {}

    // private:
    // Grid pointers as supplied to the constructor.
    const ADataType* p_a_grid;
    const BDataType* p_b_grid;
    CDataType* p_c_grid;

    // Problem sizes and strides as supplied to the constructor.
    index_t M;
    index_t N;
    index_t K;
    index_t StrideA;
    index_t StrideB;
    index_t StrideC;

    // Values returned by the GridwiseGemm::Calculate* helpers.
    index_t MPadded;
    index_t NPadded;
    index_t KPadded;
    index_t AK0;
    index_t BK0;

    // Grid descriptors built by the GridwiseGemm::Make*Descriptor helpers.
    AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
    BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
    CGridDesc_M_N c_grid_desc_m_n;
};
using
Argument
=
typename
GridwiseGemm
::
Argument
;
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
void
Print
(
const
Argument
&
karg
)
{
karg
.
Print
();
}
float
Run
(
const
Argument
&
karg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
...
...
@@ -253,16 +158,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
Argument
,
true
>
;
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
true
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
Argument
,
false
>
;
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
148d9e57
...
...
@@ -17,12 +17,12 @@
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
Argument
,
bool
HasMainKBlockLoop
>
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_xdl_cshuffle_v1_simplified
(
Argument
karg
)
kernel_gemm_xdl_cshuffle_v1_simplified
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
...
...
@@ -383,6 +383,85 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
// Descriptor types for the A, B and C grids. Each alias is the return type of the
// matching Make*Descriptor function; the literal 1s are placeholder arguments that
// sit inside decltype and are therefore never evaluated.
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N       = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// Argument
struct
Argument
:
public
tensor_operation
::
device
::
BaseArgument
{
__host__
Argument
(
const
FloatAB
*
p_a_grid_
,
const
FloatAB
*
p_b_grid_
,
FloatC
*
p_c_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
)
:
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_c_grid
{
p_c_grid_
},
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideC
{
StrideC_
},
MPadded
{
CalculateMPadded
(
M_
)},
NPadded
{
CalculateNPadded
(
N_
)},
KPadded
{
CalculateKPadded
(
K_
)},
AK0
{
CalculateAK0
(
K_
)},
BK0
{
CalculateBK0
(
K_
)},
a_grid_desc_ak0_m_ak1
{
MakeAGridDescriptor_AK0_M_AK1
(
M_
,
CalculateMPadded
(
M_
),
K_
,
CalculateKPadded
(
K_
),
StrideA_
,
CalculateAK0
(
K_
))},
b_grid_desc_bk0_n_bk1
{
MakeBGridDescriptor_BK0_N_BK1
(
K_
,
CalculateKPadded
(
K_
),
N_
,
CalculateNPadded
(
N_
),
StrideB_
,
CalculateBK0
(
K_
))},
c_grid_desc_m_n
{
MakeCGridDescriptor_M_N
(
M_
,
CalculateMPadded
(
M_
),
N_
,
CalculateNPadded
(
N_
),
StrideC_
)}
{
}
__host__
__device__
void
Print
()
const
{
std
::
cout
<<
"arg {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SC:"
<<
StrideC
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"KP:"
<<
KPadded
<<
", "
<<
"AK0:"
<<
AK0
<<
", "
<<
"BK0:"
<<
BK0
<<
"}"
<<
std
::
endl
;
}
__host__
__device__
Argument
(
const
Argument
&
)
=
default
;
__host__
__device__
~
Argument
()
override
{}
// private:
const
FloatAB
*
p_a_grid
;
const
FloatAB
*
p_b_grid
;
FloatC
*
p_c_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
index_t
KPadded
;
index_t
AK0
;
index_t
BK0
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
;
CGridDesc_M_N
c_grid_desc_m_n
;
};
// FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
...
...
@@ -447,7 +526,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Argument
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
...
...
@@ -590,7 +668,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
print_bytes
(
memory
,
sizeof
(
T
));
}
template
<
bool
HasMainKBlockLoop
,
typename
Argument
>
template
<
bool
HasMainKBlockLoop
>
__device__
static
void
Run
(
const
Argument
&
karg
,
void
*
__restrict__
p_shared
)
{
const
FloatAB
*
p_a_grid
=
karg
.
p_a_grid
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment