Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
987cc54d
Commit
987cc54d
authored
Feb 04, 2025
by
ThomasNing
Browse files
Finish the integration to develop and have the correct result
parent
3b301468
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
62 additions
and
66 deletions
+62
-66
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+48
-50
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+9
-11
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
+5
-5
No files found.
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
987cc54d
...
@@ -114,8 +114,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -114,8 +114,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
has_hot_loop_v
,
has_hot_loop_v
,
tail_number_v
>
;
tail_number_v
>
;
using
GemmPipeline
=
using
GemmPipeline
=
GEMM_PIPELINE
<
UniversalGemmProblem
>
;
GEMM_PIPELINE
<
UniversalGemmProblem
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CDataType
,
...
@@ -241,64 +240,63 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -241,64 +240,63 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
Run
(
ck_tile
::
bool_constant
<
true
>
{},
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
}
}
}
#endif
#endif
}
else
{
// Tail number always Full - #PrefetchStages
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
Run
(
ck_tile
::
bool_constant
<
false
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Full
>
{});
}
else
else
{
{
// Tail number always Full - #PrefetchStages
std
::
ostringstream
err
;
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
err
<<
"When there's no hot loop, this tail number
\"
"
<<
tail_num
{
<<
"
\"
is not supported! PrefetchStages: "
<<
BaseGemmPipeline
::
PrefetchStages
Run
(
ck_tile
::
bool_constant
<
false
>
{},
<<
"
\n
File: "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Full
>
{});
throw
std
::
runtime_error
(
err
.
str
());
}
else
{
std
::
ostringstream
err
;
err
<<
"When there's no hot loop, this tail number
\"
"
<<
tail_num
<<
"
\"
is not supported! PrefetchStages: "
<<
BaseGemmPipeline
::
PrefetchStages
<<
"
\n
File: "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
}
}
return
ave_time
;
}
}
return
ave_time
;
}
#include "run_gemm_example.inc"
#include "run_gemm_example.inc"
int
run_gemm_example
(
int
argc
,
char
*
argv
[])
int
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
if
(
!
result
)
return
-
1
;
return
-
1
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"C"
)
else
if
(
a_layout
==
"C"
&&
b_layout
==
"C"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"R"
)
else
if
(
a_layout
==
"C"
&&
b_layout
==
"R"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
}
else
else
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
"Unsupported data layout configuration for A,B and C tensors!"
);
}
}
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
987cc54d
...
@@ -490,8 +490,6 @@ struct GemmKernel
...
@@ -490,8 +490,6 @@ struct GemmKernel
const
auto
&
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
const
auto
&
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr_0
);
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr_0
);
// Run Epilogue Pipeline
// Run Epilogue Pipeline
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
...
@@ -548,7 +546,7 @@ struct GemmKernel
...
@@ -548,7 +546,7 @@ struct GemmKernel
EpiloguePipeline
{}
EpiloguePipeline
{}
.
template
operator
()
<
decltype
(
c_block_window
),
decltype
(
c_block_tile
),
DstInMemOp
>(
.
template
operator
()
<
decltype
(
c_block_window
),
decltype
(
c_block_tile
),
DstInMemOp
>(
c_block_window
,
c_block_tile
,
smem_ptr_0
,
smem_ptr_1
);
c_block_window
,
c_block_tile
,
smem_ptr_0
);
}
}
CK_TILE_DEVICE
void
operator
()(
GemmKernelArgs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
GemmKernelArgs
kargs
)
const
...
@@ -596,14 +594,14 @@ struct GemmKernel
...
@@ -596,14 +594,14 @@ struct GemmKernel
if
constexpr
(
GemmPipeline
::
DoubleSmemBuffer
==
true
)
if
constexpr
(
GemmPipeline
::
DoubleSmemBuffer
==
true
)
{
{
RunGemm2LDS
<
memory_operation_enum
::
atomic_add
>
(
a_ptr
,
RunGemm2LDS
<
memory_operation_enum
::
atomic_add
>
(
a_ptr
,
b_ptr
,
b_ptr
,
c_ptr
,
c_ptr
,
smem_ptr_0
,
smem_ptr_0
,
smem_ptr_1
,
smem_ptr_1
,
kargs
,
kargs
,
splitk_batch_offset
,
splitk_batch_offset
,
i_m
,
i_m
,
i_n
);
i_n
);
}
}
else
else
{
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
View file @
987cc54d
...
@@ -69,9 +69,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
...
@@ -69,9 +69,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
VectorSizeA
=
Policy
::
template
GetVectorSizeA
<
Problem
>();
static
constexpr
index_t
Get
VectorSizeA
()
{
return
Policy
::
template
GetVectorSizeA
<
Problem
>();
}
static
constexpr
index_t
VectorSizeB
=
Policy
::
template
GetVectorSizeB
<
Problem
>();
static
constexpr
index_t
Get
VectorSizeB
()
{
return
Policy
::
template
GetVectorSizeB
<
Problem
>();
}
static
constexpr
index_t
VectorSizeC
=
Policy
::
template
GetVectorSizeC
<
Problem
>();
static
constexpr
index_t
Get
VectorSizeC
()
{
return
Policy
::
template
GetVectorSizeC
<
Problem
>();
}
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
...
@@ -117,9 +117,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
...
@@ -117,9 +117,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
constexpr
index_t
B_LDS_Read_Width
=
KPerXDL
;
constexpr
index_t
B_LDS_Read_Width
=
KPerXDL
;
constexpr
index_t
A_Buffer_Load_Inst_Num
=
constexpr
index_t
A_Buffer_Load_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
VectorSizeA
);
MPerBlock
*
KPerBlock
/
(
BlockSize
*
Get
VectorSizeA
()
);
constexpr
index_t
B_Buffer_Load_Inst_Num
=
constexpr
index_t
B_Buffer_Load_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
VectorSizeB
);
NPerBlock
*
KPerBlock
/
(
BlockSize
*
Get
VectorSizeB
()
);
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment