Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1b78ca0d
Commit
1b78ca0d
authored
May 09, 2023
by
Po-Yen, Chen
Browse files
Move 'Argument' into GridwiseGemm
parent
cbc49dc2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
69 deletions
+63
-69
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
...e/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
+9
-67
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+54
-2
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
View file @
1b78ca0d
...
@@ -120,69 +120,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -120,69 +120,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
LoopSched
,
LoopSched
,
PipelineVer
>
;
PipelineVer
>
;
using
AGridDesc_K0_M_K1
=
decltype
(
GridwiseGemm
::
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
,
1
));
using
Argument
=
typename
GridwiseGemm
::
Argument
;
using
BGridDesc_K0_N_K1
=
decltype
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
1
,
1
,
1
,
1
,
1
));
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideC
{
StrideC_
},
MPadded
{
GridwiseGemm
::
CalculateMPadded
(
M_
)},
NPadded
{
GridwiseGemm
::
CalculateNPadded
(
N_
)}
{
}
__host__
void
Print
()
const
{
printf
(
"M = %d, N = %d, K = %d, "
"SA = %d, SB = %d, SC = %d, "
"MP = %d, NP = %d
\n
"
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
MPadded
,
NPadded
);
}
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
};
// Invoker
// Invoker
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
using
Argument
=
DeviceGemmXdl
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
#if DEBUG_LOG
#if DEBUG_LOG
...
@@ -216,30 +158,30 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -216,30 +158,30 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
arg
.
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
arg
.
K
))
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
Argument
,
true
>
;
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
true
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_a_grid
_
,
arg
.
p_a_grid
,
arg
.
p_b_grid
_
,
arg
.
p_b_grid
,
arg
.
p_c_grid
_
,
arg
.
p_c_grid
,
arg
);
arg
);
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
Argument
,
false
>
;
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_a_grid
_
,
arg
.
p_a_grid
,
arg
.
p_b_grid
_
,
arg
.
p_b_grid
,
arg
.
p_c_grid
_
,
arg
.
p_c_grid
,
arg
);
arg
);
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
1b78ca0d
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
Argument
,
bool
HasMainKBlockLoop
>
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
@@ -24,7 +24,7 @@ __global__ void
...
@@ -24,7 +24,7 @@ __global__ void
kernel_gemm_xdlops_v2r3
(
const
typename
GridwiseGemm
::
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_xdlops_v2r3
(
const
typename
GridwiseGemm
::
FloatAB
*
__restrict__
p_a_grid
,
const
typename
GridwiseGemm
::
FloatAB
*
__restrict__
p_b_grid
,
const
typename
GridwiseGemm
::
FloatAB
*
__restrict__
p_b_grid
,
typename
GridwiseGemm
::
FloatC
*
__restrict__
p_c_grid
,
typename
GridwiseGemm
::
FloatC
*
__restrict__
p_c_grid
,
const
Argument
karg
)
const
typename
GridwiseGemm
::
Argument
karg
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
defined(__gfx940__))
...
@@ -220,6 +220,58 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -220,6 +220,58 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}
}
}
// Argument
struct
Argument
:
public
tensor_operation
::
device
::
BaseArgument
{
__host__
Argument
(
const
FloatAB
*
p_a_grid_
,
const
FloatAB
*
p_b_grid_
,
FloatC
*
p_c_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
)
:
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_c_grid
{
p_c_grid_
},
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideC
{
StrideC_
},
MPadded
{
CalculateMPadded
(
M_
)},
NPadded
{
CalculateNPadded
(
N_
)}
{
}
__host__
void
Print
()
const
{
std
::
cout
<<
"arg {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SC:"
<<
StrideC
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
"}"
<<
std
::
endl
;
}
const
FloatAB
*
p_a_grid
;
const
FloatAB
*
p_b_grid
;
FloatC
*
p_c_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
};
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment