Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cbc49dc2
Commit
cbc49dc2
authored
May 09, 2023
by
Po-Yen, Chen
Browse files
Completely move descriptor-creation logic on device side
parent
d57f9521
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
34 deletions
+12
-34
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
...e/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
+2
-12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+10
-22
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
View file @
cbc49dc2
...
...
@@ -146,18 +146,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
StrideB
{
StrideB_
},
StrideC
{
StrideC_
},
MPadded
{
GridwiseGemm
::
CalculateMPadded
(
M_
)},
NPadded
{
GridwiseGemm
::
CalculateNPadded
(
N_
)},
a_grid_desc_k0_m_k1
{},
c_grid_desc_m_n
{}
NPadded
{
GridwiseGemm
::
CalculateNPadded
(
N_
)}
{
// Print();
a_grid_desc_k0_m_k1
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M_K1
(
M
,
MPadded
,
K
,
StrideA
);
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
MPadded
,
N
,
NPadded
,
StrideC
);
}
__host__
__device__
void
Print
()
const
__host__
void
Print
()
const
{
printf
(
"M = %d, N = %d, K = %d, "
"SA = %d, SB = %d, SC = %d, "
...
...
@@ -172,7 +165,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
NPadded
);
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
...
...
@@ -184,8 +176,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
;
CGridDesc_M_N
c_grid_desc_m_n
;
};
// Invoker
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
cbc49dc2
...
...
@@ -100,11 +100,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
#if defined(INTEGER_DIVIDE_CEIL)
#error "macro INTEGER_DIVIDE_CEIL() was already defined somewhere else"
#endif
#define INTEGER_DIVIDE_CEIL(x, y) (((x) + (y)-1) / (y))
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
return
std
::
make_tuple
(
Block2CTileMap
::
CalculateGridSize
(
M
,
N
),
1
,
1
);
...
...
@@ -112,16 +107,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__
static
auto
CalculateMPadded
(
index_t
M
)
{
return
INTEGER_DIVIDE_CEIL
(
M
,
MPerBlock
)
*
MPerBlock
;
return
math
::
integer_divide_ceil
(
M
,
MPerBlock
)
*
MPerBlock
;
}
__host__
static
auto
CalculateNPadded
(
index_t
N
)
{
return
INTEGER_DIVIDE_CEIL
(
N
,
NPerBlock
)
*
NPerBlock
;
return
math
::
integer_divide_ceil
(
N
,
NPerBlock
)
*
NPerBlock
;
}
#undef INTEGER_DIVIDE_CEIL
__host__
__device__
static
auto
__device__
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
StrideA
)
{
const
index_t
K0
=
K
/
K1
;
...
...
@@ -157,7 +151,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}
__host__
__device__
static
auto
__device__
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
)
{
const
index_t
K0
=
K
/
K1
;
...
...
@@ -193,7 +187,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}
__host__
__device__
static
auto
__device__
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
MPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideC
)
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
...
...
@@ -403,24 +397,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using
Block2CTileMap
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
>
;
template
<
bool
HasMainKBlockLoop
,
typename
Argument
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
__device__
static
void
Run
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
void
*
__restrict__
p_shared
,
const
Argument
&
karg
)
{
#define CREATE_DESC_ON_HOST 1
#if CREATE_DESC_ON_HOST
const
auto
a_grid_desc_k0_m_k1
=
karg
.
a_grid_desc_k0_m_k1
;
const
auto
c_grid_desc_m_n
=
karg
.
c_grid_desc_m_n
;
#else
const
auto
a_grid_desc_k0_m_k1
=
MakeAGridDescriptor_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
MPadded
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideC
);
#endif
const
auto
b_grid_desc_k0_n_k1
=
MakeBGridDescriptor_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideB
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
MPadded
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideC
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment