Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0cf90eaf
"torchvision/vscode:/vscode.git/clone" did not exist on "eafab6bf31b733ba4a644e765f2bbf85dce5cd2b"
Commit
0cf90eaf
authored
May 04, 2023
by
Po-Yen, Chen
Browse files
Remove unnesscary type parameters
parent
670ce6b9
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
472 additions
and
118 deletions
+472
-118
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+100
-57
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+372
-61
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
0cf90eaf
...
...
@@ -299,6 +299,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
CShuffleDataType
,
...
...
@@ -308,9 +311,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CElementwiseOperation
,
GemmSpec
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -363,29 +363,89 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideC
{
StrideC_
},
MPadded
{
GridwiseGemm
::
CalculateMPadded
(
M_
)},
NPadded
{
GridwiseGemm
::
CalculateNPadded
(
N_
)},
KPadded
{
GridwiseGemm
::
CalculateKPadded
(
K_
)},
AK0
{
GridwiseGemm
::
CalculateAK0
(
K_
)},
BK0
{
GridwiseGemm
::
CalculateBK0
(
K_
)},
a_grid_desc_ak0_m_ak1
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
M_
,
GridwiseGemm
::
CalculateMPadded
(
M_
),
K_
,
GridwiseGemm
::
CalculateKPadded
(
K_
),
StrideA_
,
GridwiseGemm
::
CalculateAK0
(
K_
))},
GridwiseGemm
::
MakeAGridDescriptor_AK0_M_AK1
(
M_
,
GridwiseGemm
::
CalculateMPadded
(
M_
),
K_
,
GridwiseGemm
::
CalculateKPadded
(
K_
),
StrideA_
,
GridwiseGemm
::
CalculateAK0
(
K_
))},
b_grid_desc_bk0_n_bk1
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
K_
,
GridwiseGemm
::
CalculateKPadded
(
K_
),
N_
,
GridwiseGemm
::
CalculateNPadded
(
N_
),
StrideB_
,
GridwiseGemm
::
CalculateBK0
(
K_
))},
c_grid_desc_m_n
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
M_
,
GridwiseGemm
::
CalculateMPadded
(
M_
),
N_
,
GridwiseGemm
::
CalculateNPadded
(
N_
),
StrideC_
)}
GridwiseGemm
::
MakeBGridDescriptor_BK0_N_BK1
(
K_
,
GridwiseGemm
::
CalculateKPadded
(
K_
),
N_
,
GridwiseGemm
::
CalculateNPadded
(
N_
),
StrideB_
,
GridwiseGemm
::
CalculateBK0
(
K_
))},
c_grid_desc_m_n
{
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M_
,
GridwiseGemm
::
CalculateMPadded
(
M_
),
N_
,
GridwiseGemm
::
CalculateNPadded
(
N_
),
StrideC_
)}
{
}
__host__
__device__
Argument
(
const
Argument
&
)
=
default
;
__host__
__device__
void
Print
()
const
{
printf
(
"arg {M: %d, N: %d, K: %d, "
"SA: %d, SB: %d, SC: %d, "
"MP: %d, NP: %d, KP: %d, "
"AK0: %d, BK0: %d}
\n
"
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
MPadded
,
NPadded
,
KPadded
,
AK0
,
BK0
);
// std::cout << "arg {"
// << "M:" << M << ", "
// << "N:" << N << ", "
// << "K:" << K << ", "
// << "SA:" << StrideA << ", "
// << "SB:" << StrideB << ", "
// << "SC:" << StrideC << ", "
// << "MP:" << MPadded << ", "
// << "NP:" << NPadded << ", "
// << "KP:" << KPadded << ", "
// << "AK0:" << AK0 << ", "
// << "BK0:" << BK0 << "}" << std::endl;
}
__host__
__device__
Argument
(
const
Argument
&
other
)
:
p_a_grid
{
other
.
p_a_grid
},
p_b_grid
{
other
.
p_b_grid
},
p_c_grid
{
other
.
p_c_grid
},
M
{
other
.
M
},
N
{
other
.
N
},
K
{
other
.
K
},
StrideA
{
other
.
StrideA
},
StrideB
{
other
.
StrideB
},
StrideC
{
other
.
StrideC
},
MPadded
{
other
.
MPadded
},
NPadded
{
other
.
NPadded
},
KPadded
{
other
.
KPadded
},
AK0
{
other
.
AK0
},
BK0
{
other
.
BK0
},
a_grid_desc_ak0_m_ak1
{
other
.
a_grid_desc_ak0_m_ak1
},
b_grid_desc_bk0_n_bk1
{
other
.
b_grid_desc_bk0_n_bk1
},
c_grid_desc_m_n
{
other
.
c_grid_desc_m_n
}
{
}
__host__
__device__
~
Argument
()
override
{}
...
...
@@ -396,6 +456,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
index_t
KPadded
;
index_t
AK0
;
index_t
BK0
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
;
CGridDesc_M_N
c_grid_desc_m_n
;
...
...
@@ -406,28 +474,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
using
Argument
=
DeviceOp
::
Argument
;
void
Print
(
const
Argument
&
karg
)
{
karg
.
Print
();
}
float
Run
(
const
Argument
&
karg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
if
(
stream_config
.
log_level_
>
0
)
{
// std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
// "
// << karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
Print
(
karg
);
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
karg
.
a_grid_desc_ak0_m_ak1
,
karg
.
b_grid_desc_bk0_n_bk1
,
karg
.
c_grid_desc_m_n
))
if
(
!
GridwiseGemm
::
CheckValidity
(
karg
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
...
...
@@ -441,15 +497,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
Argument
,
true
>
;
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
Argument
,
true
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
Argument
,
false
>
;
ave_time
=
launch_and_time_kernel
(
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
Argument
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
}
...
...
@@ -472,22 +530,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
karg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
{
return
false
;
}
if
((
karg
.
K
%
AK1
!=
0
||
karg
.
K
%
BK1
!=
0
)
&&
!
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
||
GemmSpec
==
GemmSpecialization
::
KPadding
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
karg
.
a_grid_desc_ak0_m_ak1
,
karg
.
b_grid_desc_bk0_n_bk1
,
karg
.
c_grid_desc_m_n
);
return
GridwiseGemm
::
CheckValidity
(
karg
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
0cf90eaf
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment