Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f4ea00fc
Commit
f4ea00fc
authored
May 07, 2023
by
Po-Yen, Chen
Browse files
Make sure methods are only invoked on right place
parent
880bbc45
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
51 deletions
+20
-51
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+10
-5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+10
-46
No files found.
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
f4ea00fc
...
...
@@ -217,7 +217,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
DeviceOp
::
MakeDescriptor_M
({
M_
,
N_
},
{
I1
,
StrideC_
},
grid_size
,
BlockSize
);
}
p_aux_2_grid_
=
p_workspace
+
Parent
::
c_grid_desc_m_n
.
GetElementSpaceSize
();
p_aux_2_grid_
=
p_workspace
+
Get
C
ElementSpaceSize
(
M_
,
N_
,
StrideC_
);
}
// private:
...
...
@@ -561,6 +561,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
return
str
.
str
();
}
static
std
::
size_t
GetCElementSpaceSize
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
{
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
GridwiseGemm
::
CalculateMPadded
(
M
),
N
,
GridwiseGemm
::
CalculateNPadded
(
N
),
StrideC
);
return
c_grid_desc_m_n
.
GetElementSpaceSize
();
}
std
::
size_t
GetWorkspaceSize
(
index_t
M
,
index_t
N
,
[[
maybe_unused
]]
index_t
K
,
...
...
@@ -568,10 +576,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[
maybe_unused
]]
index_t
StrideB
,
index_t
StrideC
)
override
{
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
GridwiseGemm
::
CalculateMPadded
(
M
),
N
,
GridwiseGemm
::
CalculateNPadded
(
N
),
StrideC
);
return
2
*
sizeof
(
CDataType
)
*
c_grid_desc_m_n
.
GetElementSpaceSize
();
return
2
*
sizeof
(
CDataType
)
*
GetCElementSpaceSize
(
M
,
N
,
StrideC
);
}
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
f4ea00fc
...
...
@@ -170,7 +170,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
__host__
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
KPad
,
index_t
StrideA
,
index_t
AK0
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
...
...
@@ -252,7 +252,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
K
,
index_t
KPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
,
index_t
BK0
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
...
...
@@ -387,10 +387,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
,
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
,
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
,
1
,
1
));
// Argument
struct
Argument
:
public
tensor_operation
::
device
::
BaseArgument
{
...
...
@@ -419,7 +415,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
}
__host__
__device__
void
Print
()
const
__host__
void
Print
()
const
{
std
::
cout
<<
"arg {"
<<
"M:"
<<
M
<<
", "
...
...
@@ -435,10 +431,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<<
"BK0:"
<<
BK0
<<
"}"
<<
std
::
endl
;
}
__host__
__device__
Argument
(
const
Argument
&
)
=
default
;
__host__
__device__
~
Argument
()
override
{}
index_t
M
;
index_t
N
;
index_t
K
;
...
...
@@ -456,7 +448,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
...
...
@@ -464,7 +456,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1Number
,
AK1Number
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
...
...
@@ -472,8 +464,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1Number
,
BK1Number
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
__device__
static
constexpr
auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
...
...
@@ -488,7 +479,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
...
...
@@ -516,7 +507,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
__host__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
...
...
@@ -601,7 +592,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
__host__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
KPerBlock
;
...
...
@@ -609,7 +600,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
template
<
typename
CGridDesc
>
__host__
__device__
static
constexpr
auto
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
...
...
@@ -631,33 +622,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// return block_id to C matrix tile idx (m0, n0) mapping
using
Block2CTileMap
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
>
;
__host__
__device__
static
void
print_bytes
(
const
uint8_t
*
memory
,
std
::
size_t
size
)
{
(
void
)
memory
;
(
void
)
size
;
for
(
std
::
size_t
idx
=
0
;
idx
<
size
;
++
idx
)
{
if
(
idx
%
10
==
0
)
{
printf
(
"
\n
"
);
}
printf
(
"0x%02X "
,
static_cast
<
unsigned
>
(
memory
[
idx
]));
}
printf
(
"
\n
"
);
}
template
<
typename
T
>
__host__
__device__
static
void
print_bytes
(
const
T
&
obj
)
{
uint8_t
memory
[
sizeof
(
T
)];
memcpy
(
memory
,
&
obj
,
sizeof
(
T
));
print_bytes
(
memory
,
sizeof
(
T
));
}
template
<
bool
HasMainKBlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment