Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ef5afc55
"...composable_kernel.git" did not exist on "ad09ebdb531285c35f7c45be68db7fd52b5dc082"
Commit
ef5afc55
authored
May 04, 2023
by
Po-Yen, Chen
Browse files
Reserve kernel arg as whole object in interfaces
parent
613dcc6b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
87 additions
and
134 deletions
+87
-134
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+72
-99
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+14
-34
No files found.
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
ef5afc55
...
@@ -18,7 +18,7 @@ struct BaseArgument
...
@@ -18,7 +18,7 @@ struct BaseArgument
BaseArgument
(
const
BaseArgument
&
)
=
default
;
BaseArgument
(
const
BaseArgument
&
)
=
default
;
BaseArgument
&
operator
=
(
const
BaseArgument
&
)
=
default
;
BaseArgument
&
operator
=
(
const
BaseArgument
&
)
=
default
;
virtual
~
BaseArgument
()
{}
__host__
__device__
virtual
~
BaseArgument
()
{}
void
*
p_workspace_
=
nullptr
;
void
*
p_workspace_
=
nullptr
;
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
ef5afc55
...
@@ -348,54 +348,58 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -348,54 +348,58 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
ADataType
*
p_a_grid
,
__host__
Argument
(
const
ADataType
*
p_a_grid
_
,
const
BDataType
*
p_b_grid
,
const
BDataType
*
p_b_grid
_
,
CDataType
*
p_c_grid
,
CDataType
*
p_c_grid
_
,
index_t
M
,
index_t
M
_
,
index_t
N
,
index_t
N
_
,
index_t
K
,
index_t
K
_
,
index_t
StrideA
,
index_t
StrideA
_
,
index_t
StrideB
,
index_t
StrideB
_
,
index_t
StrideC
)
index_t
StrideC
_
)
:
p_a_grid
_
{
p_a_grid
},
:
p_a_grid
{
p_a_grid
_
},
p_b_grid
_
{
p_b_grid
},
p_b_grid
{
p_b_grid
_
},
p_c_grid
_
{
p_c_grid
},
p_c_grid
{
p_c_grid
_
},
M
_
{
M
},
M
{
M
_
},
N
_
{
N
},
N
{
N
_
},
K
_
{
K
},
K
{
K
_
},
a_grid_desc_ak0_m_ak1
_
{
a_grid_desc_ak0_m_ak1
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
M
,
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
M
_
,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateMPadded
(
M
_
),
K
,
K
_
,
GridwiseGemm
::
CalculateKPadded
(
K
),
GridwiseGemm
::
CalculateKPadded
(
K
_
),
StrideA
,
StrideA
_
,
GridwiseGemm
::
CalculateAK0
(
K
))},
GridwiseGemm
::
CalculateAK0
(
K
_
))},
b_grid_desc_bk0_n_bk1
_
{
b_grid_desc_bk0_n_bk1
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
K
,
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
K
_
,
GridwiseGemm
::
CalculateKPadded
(
K
),
GridwiseGemm
::
CalculateKPadded
(
K
_
),
N
,
N
_
,
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateNPadded
(
N
_
),
StrideB
,
StrideB
_
,
GridwiseGemm
::
CalculateBK0
(
K
))},
GridwiseGemm
::
CalculateBK0
(
K
_
))},
c_grid_desc_m_n
_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
M
,
c_grid_desc_m_n
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
M
_
,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateMPadded
(
M
_
),
N
,
N
_
,
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateNPadded
(
N
_
),
StrideC
)},
StrideC
_
)},
kraw_
{
K
}
kraw_
{
K
_
}
{
{
}
}
__host__
__device__
Argument
(
const
Argument
&
)
=
default
;
__host__
__device__
~
Argument
()
override
{}
// private:
// private:
const
ADataType
*
p_a_grid
_
;
const
ADataType
*
p_a_grid
;
const
BDataType
*
p_b_grid
_
;
const
BDataType
*
p_b_grid
;
CDataType
*
p_c_grid
_
;
CDataType
*
p_c_grid
;
index_t
M
_
;
index_t
M
;
index_t
N
_
;
index_t
N
;
index_t
K
_
;
index_t
K
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
;
CGridDesc_M_N
c_grid_desc_m_n
_
;
CGridDesc_M_N
c_grid_desc_m_n
;
index_t
kraw_
;
index_t
kraw_
;
};
};
...
@@ -408,78 +412,47 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -408,78 +412,47 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
{
#if DEBUG_LOG
#if DEBUG_LOG
{
{
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_{"
// std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<<
karg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
<<
", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<<
karg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
)
<<
", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<<
karg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std
::
cout
<<
"arg.b_grid_desc_bk0_n_bk1_{"
// std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<<
karg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
<<
", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<<
karg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
)
<<
", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<<
karg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
// std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
<<
karg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
// "
// << karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
}
#endif
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
if
(
!
GridwiseGemm
::
CheckValidity
(
karg
.
a_grid_desc_ak0_m_ak1
_
,
karg
.
b_grid_desc_bk0_n_bk1
_
,
karg
.
c_grid_desc_m_n
_
))
karg
.
a_grid_desc_ak0_m_ak1
,
karg
.
b_grid_desc_bk0_n_bk1
,
karg
.
c_grid_desc_m_n
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
}
index_t
gdx
,
gdy
,
gdz
;
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
karg
.
M
_
,
karg
.
N
_
);
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
karg
.
M
,
karg
.
N
);
const
auto
K
=
GridwiseGemm
::
CalculateAK0
(
karg
.
K
_
)
*
AK1
;
const
auto
K
=
GridwiseGemm
::
CalculateAK0
(
karg
.
K
)
*
AK1
;
float
ave_time
=
0
;
float
ave_time
=
0
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
Argument
,
true
>
;
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
ave_time
=
launch_and_time_kernel
(
CDataType
,
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
CGridDesc_M_N
,
true
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_
,
karg
.
p_b_grid_
,
karg
.
p_c_grid_
,
karg
.
a_grid_desc_ak0_m_ak1_
,
karg
.
b_grid_desc_bk0_n_bk1_
,
karg
.
c_grid_desc_m_n_
);
}
}
else
else
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
Argument
,
false
>
;
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ave_time
=
launch_and_time_kernel
(
ADataType
,
// TODO: distinguish A/B datatype
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
CDataType
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
CGridDesc_M_N
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_
,
karg
.
p_b_grid_
,
karg
.
p_c_grid_
,
karg
.
a_grid_desc_ak0_m_ak1_
,
karg
.
b_grid_desc_bk0_n_bk1_
,
karg
.
c_grid_desc_m_n_
);
}
}
return
ave_time
;
return
ave_time
;
...
@@ -516,7 +489,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -516,7 +489,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}
return
GridwiseGemm
::
CheckValidity
(
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1
_
,
arg
.
b_grid_desc_bk0_n_bk1
_
,
arg
.
c_grid_desc_m_n
_
);
arg
.
a_grid_desc_ak0_m_ak1
,
arg
.
b_grid_desc_bk0_n_bk1
,
arg
.
c_grid_desc_m_n
);
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
ef5afc55
...
@@ -17,41 +17,19 @@
...
@@ -17,41 +17,19 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
Argument
,
bool
HasMainKBlockLoop
>
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_xdl_cshuffle_v1
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_xdl_cshuffle_v1
(
const
Argument
karg
)
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
c_grid_desc_m_n
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
karg
,
p_shared
);
p_b_grid
,
p_c_grid
,
p_shared
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m_n
);
#else
#else
ignore
=
p_a_grid
;
ignore
=
karg
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_m_n
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
...
@@ -322,15 +300,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -322,15 +300,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
Block2CTileMap
=
remove_cvref_t
<
decltype
(
MakeBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
using
Block2CTileMap
=
remove_cvref_t
<
decltype
(
MakeBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
template
<
bool
HasMainKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
typename
Argument
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
Argument
karg
,
void
*
__restrict__
p_shared
)
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
c_grid_desc_m_n
)
{
{
const
FloatAB
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatAB
*
p_b_grid
=
karg
.
p_b_grid
;
FloatC
*
p_c_grid
=
karg
.
p_c_grid
;
const
auto
&
a_grid_desc_ak0_m_ak1
=
karg
.
a_grid_desc_ak0_m_ak1
;
const
auto
&
b_grid_desc_bk0_n_bk1
=
karg
.
b_grid_desc_bk0_n_bk1
;
const
auto
&
c_grid_desc_m_n
=
karg
.
c_grid_desc_m_n
;
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment