Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
854ccaa5
Commit
854ccaa5
authored
May 06, 2023
by
Po-Yen, Chen
Browse files
Always create B grid descriptor on device side
parent
04f6c31c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
32 deletions
+35
-32
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
...e/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
+13
-13
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+22
-19
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
View file @
854ccaa5
...
@@ -257,11 +257,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -257,11 +257,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
MPadded
{
GridwiseGemm
::
CalculateMPadded
(
M_
)},
MPadded
{
GridwiseGemm
::
CalculateMPadded
(
M_
)},
NPadded
{
GridwiseGemm
::
CalculateNPadded
(
N_
)},
NPadded
{
GridwiseGemm
::
CalculateNPadded
(
N_
)},
a_grid_desc_k0_m_k1
{},
a_grid_desc_k0_m_k1
{},
b_grid_desc_k0_n_k1
{},
c_grid_desc_m_n
{}
c_grid_desc_m_n
{}
{
{
a_grid_desc_k0_m_k1
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M_K1
(
M
,
MPadded
,
K
,
StrideA
);
a_grid_desc_k0_m_k1
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M_K1
(
M
,
MPadded
,
K
,
StrideA
);
b_grid_desc_k0_n_k1
=
GridwiseGemm
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
NPadded
,
StrideB
);
c_grid_desc_m_n
=
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
MPadded
,
N
,
NPadded
,
StrideC
);
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
MPadded
,
N
,
NPadded
,
StrideC
);
}
}
...
@@ -279,7 +277,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -279,7 +277,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t
MPadded
;
index_t
MPadded
;
index_t
NPadded
;
index_t
NPadded
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
;
CGridDesc_M_N
c_grid_desc_m_n
;
CGridDesc_M_N
c_grid_desc_m_n
;
};
};
...
@@ -292,16 +289,19 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -292,16 +289,19 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
{
{
#if DEBUG_LOG
#if DEBUG_LOG
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
// std::cout << "arg.a_grid_desc_k0_m_k1_{" <<
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
// arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
// << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
// << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
// std::cout << "arg.b_grid_desc_k0_n_k1_{" <<
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
// arg.b_grid_desc_k0_n_k1_.GetLength(I0)
// << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
// << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
// std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
// "
// << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
}
#endif
#endif
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
854ccaa5
...
@@ -312,26 +312,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -312,26 +312,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
(
NPerBlock
%
(
NXdlPerWave
*
NPerXDL
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXDL
))
==
0
,
"Invalid tuning param!"
);
"Invalid tuning param!"
);
const
auto
M
=
karg
.
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
(
void
)
karg
;
const
auto
N
=
karg
.
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
return
true
;
const
auto
K0
=
karg
.
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
if
(
!
(
M
==
karg
.
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
karg
.
c_grid_desc_m_n
.
GetLength
(
I1
)
&&
// const auto M = karg.a_grid_desc_k0_m_k1.GetLength(I1);
K0
==
karg
.
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
// const auto N = karg.b_grid_desc_k0_n_k1.GetLength(I1);
K1
==
karg
.
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
// const auto K0 = karg.a_grid_desc_k0_m_k1.GetLength(I0);
K1
==
karg
.
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
return
false
;
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
// if(!(M == karg.c_grid_desc_m_n.GetLength(I0) && N == karg.c_grid_desc_m_n.GetLength(I1)
return
false
;
// &&
// K0 == karg.b_grid_desc_k0_n_k1.GetLength(I0) &&
// K1 == karg.a_grid_desc_k0_m_k1.GetLength(I2) &&
// K1 == karg.b_grid_desc_k0_n_k1.GetLength(I2)))
// return false;
//
check gridwise gemm pipeline
//
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
const
auto
num_k_loop
=
K0
/
K0PerBlock
;
// return false
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
// // check gridwise gemm pipeline
{
// const auto num_k_loop = K0 / K0PerBlock;
return
false
;
}
// if(!GridwiseGemmPipe::IsSupported(num_k_loop))
// {
// return false;
// }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
...
@@ -408,16 +412,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -408,16 +412,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
#define CREATE_DESC_ON_HOST 1
#define CREATE_DESC_ON_HOST 1
#if CREATE_DESC_ON_HOST
#if CREATE_DESC_ON_HOST
const
auto
a_grid_desc_k0_m_k1
=
karg
.
a_grid_desc_k0_m_k1
;
const
auto
a_grid_desc_k0_m_k1
=
karg
.
a_grid_desc_k0_m_k1
;
const
auto
b_grid_desc_k0_n_k1
=
karg
.
b_grid_desc_k0_n_k1
;
const
auto
c_grid_desc_m_n
=
karg
.
c_grid_desc_m_n
;
const
auto
c_grid_desc_m_n
=
karg
.
c_grid_desc_m_n
;
#else
#else
const
auto
a_grid_desc_k0_m_k1
=
const
auto
a_grid_desc_k0_m_k1
=
MakeAGridDescriptor_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
);
MakeAGridDescriptor_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
);
const
auto
b_grid_desc_k0_n_k1
=
MakeBGridDescriptor_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideB
);
const
auto
c_grid_desc_m_n
=
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
MPadded
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideC
);
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
MPadded
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideC
);
#endif
#endif
const
auto
b_grid_desc_k0_n_k1
=
MakeBGridDescriptor_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideB
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n
);
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment