Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
610f9a34
Commit
610f9a34
authored
Feb 12, 2025
by
Sudhir Kylasa
Browse files
Addressing code review comments.
parent
a3678d26
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
22 additions
and
27 deletions
+22
-27
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+1
-1
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+4
-4
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+12
-17
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+1
-1
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+4
-4
No files found.
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
610f9a34
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_MEMORY
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_MEMORY
#endif
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_
COMPUTE
)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_
MEMORY
)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
610f9a34
...
@@ -94,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -94,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if
(
!
result
)
if
(
!
result
)
return
-
1
;
return
-
1
;
using
ADataType
=
typename
Gemm
Basic
TypeConfig
<
PrecType
>::
ADataType
;
using
ADataType
=
typename
GemmTypeConfig
<
PrecType
>::
ADataType
;
using
BDataType
=
typename
Gemm
Basic
TypeConfig
<
PrecType
>::
BDataType
;
using
BDataType
=
typename
GemmTypeConfig
<
PrecType
>::
BDataType
;
using
CDataType
=
typename
Gemm
Basic
TypeConfig
<
PrecType
>::
CDataType
;
using
CDataType
=
typename
GemmTypeConfig
<
PrecType
>::
CDataType
;
using
AccDataType
=
typename
Gemm
Basic
TypeConfig
<
PrecType
>::
AccDataType
;
using
AccDataType
=
typename
GemmTypeConfig
<
PrecType
>::
AccDataType
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
...
...
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
610f9a34
...
@@ -240,8 +240,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -240,8 +240,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc"
#include "run_gemm_example.inc"
void
run_gemm_
instance
(
std
::
string
data_type
,
std
::
string
a_layout
,
std
::
string
b_layout
)
int
run_gemm_
example
(
int
argc
,
char
*
argv
[]
)
{
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
{
{
if
(
data_type
==
"fp16"
)
if
(
data_type
==
"fp16"
)
...
@@ -340,20 +351,4 @@ void run_gemm_instance(std::string data_type, std::string a_layout, std::string
...
@@ -340,20 +351,4 @@ void run_gemm_instance(std::string data_type, std::string a_layout, std::string
}
}
}
}
int
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
return
run_gemm_instance
(
data_type
,
a_layout
,
b_layout
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
View file @
610f9a34
...
@@ -142,7 +142,7 @@ struct CShuffleEpilogue
...
@@ -142,7 +142,7 @@ struct CShuffleEpilogue
TileDistributionEncodingPattern2D
<
kBlockSize
,
TileDistributionEncodingPattern2D
<
kBlockSize
,
kMPerIteration
,
kMPerIteration
,
kNPerIteration
,
kNPerIteration
,
GetVectorSizeC
<
ODataType
>
(),
GetVectorSizeC
(),
tile_distribution_pattern
::
thread_raked
>
;
tile_distribution_pattern
::
thread_raked
>
;
constexpr
auto
dram_tile_distribution
=
TileEncodingPattern
::
Make2DStaticTileDistribution
();
constexpr
auto
dram_tile_distribution
=
TileEncodingPattern
::
Make2DStaticTileDistribution
();
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
610f9a34
...
@@ -248,7 +248,7 @@ struct GemmKernel
...
@@ -248,7 +248,7 @@ struct GemmKernel
<<
std
::
endl
;
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
kargs
.
N
%
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
!=
0
)
if
(
kargs
.
N
%
EpiloguePipeline
::
GetVectorSizeC
()
!=
0
)
{
{
std
::
cerr
<<
"N is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
std
::
cerr
<<
"N is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
return
false
;
return
false
;
...
@@ -263,7 +263,7 @@ struct GemmKernel
...
@@ -263,7 +263,7 @@ struct GemmKernel
<<
std
::
endl
;
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
kargs
.
M
%
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
!=
0
)
if
(
kargs
.
M
%
EpiloguePipeline
::
GetVectorSizeC
()
!=
0
)
{
{
std
::
cerr
<<
"M is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
std
::
cerr
<<
"M is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
return
false
;
return
false
;
...
@@ -329,7 +329,7 @@ struct GemmKernel
...
@@ -329,7 +329,7 @@ struct GemmKernel
c_ptr
,
c_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
>
{},
number
<
EpiloguePipeline
::
GetVectorSizeC
()
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
else
else
...
@@ -527,7 +527,7 @@ struct GemmKernel
...
@@ -527,7 +527,7 @@ struct GemmKernel
{
{
// Do not compile in case where we have unsupported
// Do not compile in case where we have unsupported
// VectorSizeC & data type configuration.
// VectorSizeC & data type configuration.
if
constexpr
(
!
(
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
%
2
!=
0
&&
if
constexpr
(
!
(
EpiloguePipeline
::
GetVectorSizeC
()
%
2
!=
0
&&
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
))
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
))
{
{
RunGemm
<
memory_operation_enum
::
atomic_add
>
(
RunGemm
<
memory_operation_enum
::
atomic_add
>
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment