Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6352deaf
"git@developer.sourcefind.cn:gaoqiong/yaml-cpp.git" did not exist on "979a91692f7c52dcaa52066a752210c911a5ef64"
Commit
6352deaf
authored
May 09, 2023
by
Po-Yen, Chen
Browse files
Finish karg simplification work for DeviceGemmXdl<>
parent
1b78ca0d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
50 deletions
+73
-50
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
...e/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
+17
-30
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+56
-20
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
View file @
6352deaf
...
@@ -125,38 +125,25 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -125,38 +125,25 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
// Invoker
// Invoker
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
k
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
#if DEBUG_LOG
if
(
stream_config
.
log_level_
>
0
)
{
{
// std::cout << "arg.a_grid_desc_k0_m_k1_{" <<
karg
.
Print
();
// arg.a_grid_desc_k0_m_k1_.GetLength(I0)
// << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
// << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.b_grid_desc_k0_n_k1_{" <<
// arg.b_grid_desc_k0_n_k1_.GetLength(I0)
// << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
// << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
// "
// << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
))
if
(
!
GridwiseGemm
::
CheckValidity
(
k
arg
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
}
}
index_t
gdx
,
gdy
,
gdz
;
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
);
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
k
arg
.
M
,
k
arg
.
N
);
float
ave_time
=
0
;
float
ave_time
=
0
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
arg
.
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
k
arg
.
K
))
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
true
>
;
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
true
>
;
...
@@ -165,10 +152,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -165,10 +152,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_a_grid
,
k
arg
.
p_a_grid
,
arg
.
p_b_grid
,
k
arg
.
p_b_grid
,
arg
.
p_c_grid
,
k
arg
.
p_c_grid
,
arg
);
k
arg
);
}
}
else
else
{
{
...
@@ -179,10 +166,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -179,10 +166,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_a_grid
,
k
arg
.
p_a_grid
,
arg
.
p_b_grid
,
k
arg
.
p_b_grid
,
arg
.
p_c_grid
,
k
arg
.
p_c_grid
,
arg
);
k
arg
);
}
}
return
ave_time
;
return
ave_time
;
...
@@ -202,7 +189,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -202,7 +189,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return
true
;
return
true
;
}
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
k
arg
)
{
{
if
(
ck
::
get_device_name
()
==
"gfx908"
)
if
(
ck
::
get_device_name
()
==
"gfx908"
)
{
{
...
@@ -225,12 +212,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -225,12 +212,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return
false
;
return
false
;
}
}
if
(
arg
.
K
%
K1
!=
0
)
if
(
k
arg
.
K
%
K1
!=
0
)
{
{
return
false
;
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
);
return
GridwiseGemm
::
CheckValidity
(
k
arg
);
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
6352deaf
...
@@ -358,30 +358,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -358,30 +358,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
(
NPerBlock
%
(
NXdlPerWave
*
NPerXDL
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXDL
))
==
0
,
"Invalid tuning param!"
);
"Invalid tuning param!"
);
(
void
)
karg
;
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
return
true
;
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
{
return
false
;
}
}
// const auto M = karg.a_grid_desc_k0_m_k1.GetLength(I1);
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
// const auto N = karg.b_grid_desc_k0_n_k1.GetLength(I1);
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
// const auto K0 = karg.a_grid_desc_k0_m_k1.GetLength(I0);
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
{
return
false
;
}
}
// if(!(M == karg.c_grid_desc_m_n.GetLength(I0) && N == karg.c_grid_desc_m_n.GetLength(I1)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
// &&
{
// K0 == karg.b_grid_desc_k0_n_k1.GetLength(I0) &&
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
// K1 == karg.a_grid_desc_k0_m_k1.GetLength(I2) &&
{
// K1 == karg.b_grid_desc_k0_n_k1.GetLength(I2)))
return
false
;
// return false;
}
}
else
{
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
// if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
// return false;
{
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
// // check gridwise gemm pipeline
// check gridwise gemm pipeline
// const auto num_k_loop = K0 / K0PerBlock;
const
index_t
K0
=
karg
.
K
/
K1
;
const
auto
num_k_loop
=
K0
/
K0PerBlock
;
//
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
//
{
{
//
return false;
return
false
;
//
}
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
...
@@ -476,7 +512,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -476,7 +512,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
BElementwiseOperation
b_element_op
{};
const
BElementwiseOperation
b_element_op
{};
const
CElementwiseOperation
c_element_op
{};
const
CElementwiseOperation
c_element_op
{};
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
)
;
const
index_t
K0
=
karg
.
K
/
K1
;
const
auto
block_2_ctile_map
=
Block2CTileMap
{
karg
.
M
,
karg
.
N
};
const
auto
block_2_ctile_map
=
Block2CTileMap
{
karg
.
M
,
karg
.
N
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment