Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
71352c44
Commit
71352c44
authored
Jan 30, 2025
by
ThomasNing
Browse files
Solve the compiler issue on SHMEM conflict
parent
49316982
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
113 additions
and
41 deletions
+113
-41
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+11
-11
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+1
-1
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+101
-26
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
+0
-3
No files found.
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
71352c44
...
@@ -216,14 +216,14 @@ int run_gemm_example(int argc, char* argv[])
...
@@ -216,14 +216,14 @@ int run_gemm_example(int argc, char* argv[])
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
//
if(a_layout == "R" && b_layout == "R")
{
//
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
//
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
}
//
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
//
else if(a_layout == "R" && b_layout == "C")
{
//
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
//
}
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work.
// work.
// else if(a_layout == "C" && b_layout == "C")
// else if(a_layout == "C" && b_layout == "C")
...
@@ -234,8 +234,8 @@ int run_gemm_example(int argc, char* argv[])
...
@@ -234,8 +234,8 @@ int run_gemm_example(int argc, char* argv[])
// {
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
// }
else
//
else
{
//
{
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
//
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
//
}
}
}
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
71352c44
...
@@ -33,7 +33,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -33,7 +33,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
// Compute friendly for Intrawave scheduler
// Compute friendly for Intrawave scheduler
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Tile
=
64
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
71352c44
...
@@ -414,10 +414,10 @@ struct GemmKernel
...
@@ -414,10 +414,10 @@ struct GemmKernel
* @tparam DstInMemOp Destination memory operation (default: set).
* @tparam DstInMemOp Destination memory operation (default: set).
*/
*/
template
<
memory_operation_enum
DstInMemOp
=
memory_operation_enum
::
set
>
template
<
memory_operation_enum
DstInMemOp
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
static
void
RunGemm
(
const
ADataType
*
a_ptr
,
CK_TILE_DEVICE
static
void
RunGemm
SinglePointer
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
CDataType
*
c_ptr
,
void
*
smem_ptr
,
void
*
smem_ptr
_0
,
const
GemmKernelArgs
&
kargs
,
const
GemmKernelArgs
&
kargs
,
const
SplitKBatchOffset
&
splitk_batch_offset
,
const
SplitKBatchOffset
&
splitk_batch_offset
,
const
index_t
block_idx_m
,
const
index_t
block_idx_m
,
...
@@ -436,20 +436,63 @@ struct GemmKernel
...
@@ -436,20 +436,63 @@ struct GemmKernel
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
const
auto
&
c_block_tile
=
const
auto
&
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
[
&
]()
{
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr_0
);
if
constexpr
(
GemmPipeline
::
isDoubleSmemBuffer
==
true
)
// Run Epilogue Pipeline
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
constexpr
bool
is_output_c_reg_transposed
=
EpiloguePipeline
::
IsOutputTransposed
()
!=
GemmPipeline
::
IsTransposeC
();
if
constexpr
((
DstInMemOp
==
memory_operation_enum
::
set
)
||
(
sizeof
(
CDataType
)
>
2
)
||
(
GemmPipeline
::
VectorSizeC
%
2
==
0
&&
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_output_c_reg_transposed
))
{
{
__shared__
char
smem_ptr_1
[
GetSmemSize
()];
EpiloguePipeline
{}
return
GemmPipeline
{}.
template
operator
()
(
.
template
operator
()
<
decltype
(
c_block_window
),
decltype
(
c_block_tile
),
DstInMemOp
>
(
a
_block_window
,
b
_block_
window
,
num_loop
,
smem_ptr
,
smem_ptr_1
);
c
_block_window
,
c
_block_
tile
);
}
}
else
{
return
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
}
}
}();
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param kargs GEMM kernel arguments
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
* @tparam DstInMemOp Destination memory operation (default: set).
*/
template
<
memory_operation_enum
DstInMemOp
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
static
void
RunGemmDoublePointer
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
void
*
smem_ptr_0
,
void
*
smem_ptr_1
,
const
GemmKernelArgs
&
kargs
,
const
SplitKBatchOffset
&
splitk_batch_offset
,
const
index_t
block_idx_m
,
const
index_t
block_idx_n
)
{
// Create Gemm tensor views, pad views and tile windows
const
auto
&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
<
DstInMemOp
>
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
,
splitk_batch_offset
);
;
const
auto
&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
auto
gemm_tile_windows
=
MakeGemmTileWindows
(
gemm_pad_views
,
block_idx_m
,
block_idx_n
);
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
splitk_batch_offset
.
splitted_k
);
// Run GEMM cooperatively by whole workgroup.
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
const
auto
&
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr_0
,
smem_ptr_1
);
// Run Epilogue Pipeline
// Run Epilogue Pipeline
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
...
@@ -479,16 +522,48 @@ struct GemmKernel
...
@@ -479,16 +522,48 @@ struct GemmKernel
CDataType
*
c_ptr
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
CDataType
*
c_ptr
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
// allocate LDS
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
__shared__
char
smem_ptr_0
[
GetSmemSize
()];
__shared__
char
smem_ptr_1
[
GetSmemSize
()];
if
(
kargs
.
KBatch
==
1
)
if
(
kargs
.
KBatch
==
1
)
{
{
RunGemm
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
if
constexpr
(
GemmPipeline
::
isDoubleSmemBuffer
==
true
)
{
RunGemmDoublePointer
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr_0
,
smem_ptr_1
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
else
{
RunGemmSinglePointer
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr_0
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
}
}
else
else
{
{
RunGemm
<
memory_operation_enum
::
atomic_add
>
(
if
constexpr
(
GemmPipeline
::
isDoubleSmemBuffer
==
true
)
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
{
RunGemmDoublePointer
<
memory_operation_enum
::
atomic_add
>
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr_0
,
smem_ptr_1
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
else
{
RunGemmSinglePointer
<
memory_operation_enum
::
atomic_add
>
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr_0
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
}
}
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
View file @
71352c44
...
@@ -268,9 +268,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -268,9 +268,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
printf
(
"Tail Num: =====================================
\n
"
);
printf
(
"%d
\n
"
,
static_cast
<
int
>
(
TailNum
));
if
(
HasHotLoop
)
if
(
HasHotLoop
)
{
{
// minus 2 because we have ping-pong double buffer.
// minus 2 because we have ping-pong double buffer.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment