Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d0c56016
Commit
d0c56016
authored
Feb 10, 2025
by
illsilin
Browse files
fix clang format
parent
f1e53807
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
41 deletions
+42
-41
codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
...gen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
+4
-5
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+38
-36
No files found.
codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
View file @
d0c56016
...
@@ -62,12 +62,11 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
...
@@ -62,12 +62,11 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
// accounts for all possible combinations of Row/Col major
// accounts for all possible combinations of Row/Col major
static
Layout
ToLayout
(
bool
Trans
)
{
return
Trans
?
Layout
::
Column
:
Layout
::
Row
;
}
static
Layout
ToLayout
(
bool
Trans
)
{
return
Trans
?
Layout
::
Column
:
Layout
::
Row
;
}
// clang-format off
// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
// clang-format on
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
// instances
...
@@ -251,8 +250,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -251,8 +250,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
x
.
tile_desc
.
m_per_block
,
x
.
tile_desc
.
m_per_block
,
x
.
tile_desc
.
n_per_block
,
x
.
tile_desc
.
n_per_block
,
x
.
tile_desc
.
k_per_block
);
x
.
tile_desc
.
k_per_block
);
x
.
loop_scheduler
=
loop_scheduler
;
x
.
loop_scheduler
=
loop_scheduler
;
x
.
pipeline_version
=
pipeline_version
;
x
.
pipeline_version
=
pipeline_version
;
x
.
update_prologue
(
prologue
);
x
.
update_prologue
(
prologue
);
x
.
update_epilogue
(
epilogue
);
x
.
update_epilogue
(
epilogue
);
result
.
push_back
(
x
);
result
.
push_back
(
x
);
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
d0c56016
...
@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
...
@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
template
<
typename
ADataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
...
@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
...
@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args
.
stride_B
=
stride_B
;
args
.
stride_B
=
stride_B
;
args
.
stride_C
=
stride_C
;
args
.
stride_C
=
stride_C
;
float
ave_time
=
gemm_calc
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
float
ave_time
=
ALayout
,
BLayout
,
CLayout
>
(
gemm_calc
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_byte
=
std
::
size_t
num_byte
=
...
@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
...
@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std
::
cout
<<
"Run Gemm kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
std
::
cout
<<
"Run Gemm kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" A_Layout ="
<<
ALayout
::
name
<<
" A_Layout ="
<<
ALayout
::
name
<<
" B_Layout ="
<<
BLayout
::
name
<<
" B_Layout ="
<<
BLayout
::
name
<<
" C_Layout ="
<<
CLayout
::
name
<<
" A Type = "
<<
DataTypeTraits
<
ADataType
>::
name
<<
" C_Layout ="
<<
CLayout
::
name
<<
" B Type = "
<<
DataTypeTraits
<
BDataType
>::
name
<<
" A Type = "
<<
DataTypeTraits
<
ADataType
>::
name
<<
" C Type = "
<<
DataTypeTraits
<
CDataType
>::
name
<<
" : "
<<
ave_time
<<
" ms, "
<<
" B Type = "
<<
DataTypeTraits
<
BDataType
>::
name
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
<<
" C Type = "
<<
DataTypeTraits
<
CDataType
>::
name
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
return
ave_time
;
return
ave_time
;
}
}
...
@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if
(
!
result
)
if
(
!
result
)
return
-
1
;
return
-
1
;
using
ADataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
ADataType
;
using
ADataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
ADataType
;
using
BDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
BDataType
;
using
BDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
BDataType
;
using
CDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
CDataType
;
using
CDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
CDataType
;
using
AccDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
AccDataType
;
using
AccDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
AccDataType
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
...
@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
invoke_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
a_m_k_dev_buf
,
b_k_n_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
c_m_n_dev_buf
,
M
,
M
,
N
,
N
,
K
,
K
,
stride_A
,
stride_A
,
stride_B
,
stride_B
,
stride_C
,
stride_C
,
kbatch
,
kbatch
,
n_warmup
,
n_warmup
,
n_repeat
);
n_repeat
);
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
bool
pass
=
true
;
bool
pass
=
true
;
...
@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc,
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
(
K
,
kbatch
,
max_accumulated_value
);
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
...
@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
const
float
max_accumulated_value
=
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
(
K
,
kbatch
,
max_accumulated_value
);
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment