Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
611064a1
Commit
611064a1
authored
Oct 09, 2024
by
Adam Osewski
Browse files
Do not use macro.
parent
41fc6a24
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
44 deletions
+51
-44
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp
+51
-44
No files found.
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp
View file @
611064a1
...
@@ -84,20 +84,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -84,20 +84,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
float
ave_time
{
0
};
float
ave_time
{
0
};
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
const
auto
Run
=
[
&
](
const
auto
has_hot_loop_
,
const
auto
tail_number_
)
{
using
GemmKernel
=
ck_tile
::
remove_cvref_t
<
decltype
(
kernel
)
>
;
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
auto
kargs
=
GemmKernel
::
MakeKargs
(
args
.
p_a
,
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
args
.
p_b
,
args
.
p_c
,
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
args
.
M
,
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
args
.
N
,
BDataType
,
args
.
K
,
CDataType
,
args
.
stride_A
,
GemmShape
,
args
.
stride_B
,
ALayout
,
args
.
stride_C
);
BLayout
,
CLayout
,
const
dim3
grids
=
GemmKernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
kPadA
,
constexpr
dim3
blocks
=
GemmKernel
::
BlockSize
();
kPadB
,
kPadC
,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
has_hot_loop_v
,
tail_number_v
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_c
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
{
{
...
@@ -108,79 +125,70 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -108,79 +125,70 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
}
}
ave_time
=
ck_tile
::
launch_kernel
(
ave_time
=
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
kernel
,
grids
,
blocks
,
0
,
kargs
));
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
};
};
#define RUN_KERNEL_(has_hot_loop_, tail_number_) \
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< \
ck_tile::UniversalGemmPipelineProblem<ADataType, \
BDataType, \
CDataType, \
GemmShape, \
ALayout, \
BLayout, \
CLayout, \
kPadA, \
kPadB, \
kPadC, \
ck_tile::GemmPipelineScheduler::Intrawave, \
has_hot_loop_, \
tail_number_>>; \
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; \
Run(Kernel{});
if
(
has_hot_loop
)
if
(
has_hot_loop
)
{
{
// Tail pipeline One to Seven
// Tail pipeline One to Seven
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
{
{
RUN_KERNEL_
(
true
,
ck_tile
::
TailNumber
::
One
);
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
One
>
{});
}
}
else
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
else
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
{
RUN_KERNEL_
(
true
,
ck_tile
::
TailNumber
::
Full
);
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Full
>
{});
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
2
)
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
2
)
{
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Two
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Two
)
{
{
RUN_KERNEL_
(
true
,
ck_tile
::
TailNumber
::
Two
);
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
}
}
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
3
)
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
3
)
{
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Three
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Three
)
{
{
RUN_KERNEL_
(
true
,
ck_tile
::
TailNumber
::
Three
);
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Three
>
{});
}
}
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
4
)
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
4
)
{
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Four
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Four
)
{
{
RUN_KERNEL_
(
true
,
ck_tile
::
TailNumber
::
Four
);
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Four
>
{});
}
}
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
5
)
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
5
)
{
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Five
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Five
)
{
{
RUN_KERNEL_
(
true
,
ck_tile
::
TailNumber
::
Five
);
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Five
>
{});
}
}
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
6
)
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
6
)
{
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Six
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Six
)
{
{
RUN_KERNEL_
(
true
,
ck_tile
::
TailNumber
::
Six
);
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Six
>
{});
}
}
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
7
)
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
7
)
{
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Seven
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Seven
)
{
{
RUN_KERNEL_
(
true
,
ck_tile
::
TailNumber
::
Seven
);
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
}
}
}
}
}
}
...
@@ -189,12 +197,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -189,12 +197,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Tail number always 1
// Tail number always 1
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
{
{
RUN_KERNEL_
(
false
,
ck_tile
::
TailNumber
::
One
);
Run
(
ck_tile
::
bool_constant
<
false
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
One
>
{});
}
}
}
}
#undef RUN_KERNEL_
return
ave_time
;
return
ave_time
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment