Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8ce41034
Commit
8ce41034
authored
Feb 08, 2025
by
aska-0096
Browse files
Merge branch 'develop' of
https://github.com/ROCm/composable_kernel
into update_cka8w8_uc
parents
730b98e1
a8c5bd9b
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
615 additions
and
113 deletions
+615
-113
.azuredevops/rocm-ci.yml
.azuredevops/rocm-ci.yml
+1
-0
example/67_gemm_microscaling/gemm_mx_common.hpp
example/67_gemm_microscaling/gemm_mx_common.hpp
+22
-34
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+32
-6
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+44
-7
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+23
-6
example/ck_tile/03_gemm/script/benchmark_basic.sh
example/ck_tile/03_gemm/script/benchmark_basic.sh
+2
-1
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
+0
-0
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
+0
-0
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
+14
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
+3
-3
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
...ple/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
+13
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
+13
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
+13
-0
example/ck_tile/03_gemm/script/smoke_test_basic.sh
example/ck_tile/03_gemm/script/smoke_test_basic.sh
+18
-17
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
+18
-17
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+91
-8
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
+4
-1
include/ck/utility/blkgemmpipe_scheduler.hpp
include/ck/utility/blkgemmpipe_scheduler.hpp
+10
-2
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
+293
-10
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-1
No files found.
.azuredevops/rocm-ci.yml
View file @
8ce41034
...
@@ -14,6 +14,7 @@ trigger:
...
@@ -14,6 +14,7 @@ trigger:
branches
:
branches
:
include
:
include
:
-
develop
-
develop
-
amd-develop
paths
:
paths
:
exclude
:
exclude
:
-
.github
-
.github
...
...
example/67_gemm_microscaling/gemm_mx_common.hpp
View file @
8ce41034
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_
mx_
gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/fill.hpp"
...
@@ -315,40 +315,27 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -315,40 +315,27 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
std
::
cout
<<
"Computing GEMM on host..."
<<
std
::
endl
;
std
::
cout
<<
"Computing GEMM on host..."
<<
std
::
endl
;
}
}
Tensor
<
CDataType
>
c
({
M
,
N
});
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceMXGemm
<
ADataType
,
Tensor
<
float
>
a
({
M
,
K
});
BDataType
,
Tensor
<
float
>
b
({
K
,
N
});
CDataType
,
AccDataType
,
for
(
int
m
=
0
;
m
<
M
;
m
++
)
float
,
{
PassThrough
,
for
(
int
k
=
0
;
k
<
K
;
k
++
)
PassThrough
,
{
PassThrough
,
a
(
m
,
k
)
=
ck
::
type_convert
<
float
>
(
a_m_k
(
m
,
k
))
*
float
,
ck
::
type_convert
<
float
>
(
a_m_k_scale
(
m
,
k
/
Scale_Block_K
));
float
>
;
}
}
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
b
(
k
,
n
)
=
ck
::
type_convert
<
float
>
(
b_k_n
(
k
,
n
))
*
ck
::
type_convert
<
float
>
(
b_k_n_scale
(
k
/
Scale_Block_K
,
n
));
}
}
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
float
,
float
,
CShuffleDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
ref_gemm
.
MakeArgument
(
a
,
b
,
c
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
a_m_k_scale
,
b_k_n
,
b_k_n_scale
,
c_m_n_host_result
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
...
@@ -366,8 +353,9 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -366,8 +353,9 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
<<
((
res_verified
)
?
" (PASSED!)"
:
" (FAILED!)"
)
<<
std
::
endl
;
<<
((
res_verified
)
?
" (PASSED!)"
:
" (FAILED!)"
)
<<
std
::
endl
;
}
}
res_verified
=
res_verified
&&
res_verified
=
res_verified
&&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c
,
"Error: Incorrect results!"
);
c_m_n_host_result
,
"Error: Incorrect results!"
);
if
(
config
.
verbosity
>
0
&&
res_verified
)
if
(
config
.
verbosity
>
0
&&
res_verified
)
std
::
cout
<<
"Done."
<<
std
::
endl
;
std
::
cout
<<
"Done."
<<
std
::
endl
;
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
8ce41034
...
@@ -12,7 +12,13 @@
...
@@ -12,7 +12,13 @@
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#include "gemm_basic.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
...
@@ -25,7 +31,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -25,7 +31,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
// This part comes from the Codegen
// This part comes from the Codegen
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Tile
=
64
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
...
@@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
using
CodegenGemmShape
=
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
...
@@ -99,12 +105,32 @@ int run_gemm_example(int argc, char* argv[])
...
@@ -99,12 +105,32 @@ int run_gemm_example(int argc, char* argv[])
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
}
else
else
{
{
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
8ce41034
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_
MEMORY
)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_
COMPUTE
)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
...
@@ -43,6 +43,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
...
@@ -43,6 +43,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM.
// ToDo: Add more bias config to support different categories of GEMM.
};
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
bf16_t
>
{
using
ADataType
=
ck_tile
::
bf16_t
;
using
BDataType
=
ck_tile
::
bf16_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
bf16_t
;
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
fp8_t
>
{
using
ADataType
=
ck_tile
::
fp8_t
;
using
BDataType
=
ck_tile
::
fp8_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
half_t
;
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
bf8_t
>
{
using
ADataType
=
ck_tile
::
bf8_t
;
using
BDataType
=
ck_tile
::
bf8_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
half_t
;
};
template
<
typename
T
>
template
<
typename
T
>
struct
DataTypeTraits
;
struct
DataTypeTraits
;
...
@@ -64,13 +91,23 @@ struct DataTypeTraits<ck_tile::half_t>
...
@@ -64,13 +91,23 @@ struct DataTypeTraits<ck_tile::half_t>
static
constexpr
const
char
*
name
=
"fp16"
;
static
constexpr
const
char
*
name
=
"fp16"
;
};
};
using
Types
=
GemmBasicTypeConfig
<
ck_tile
::
half_t
>
;
template
<
>
struct
DataTypeTraits
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
// Specific type aliases for easy access
template
<
>
using
ADataType
=
Types
::
ADataType
;
struct
DataTypeTraits
<
ck_tile
::
fp8_t
>
using
BDataType
=
Types
::
BDataType
;
{
using
AccDataType
=
Types
::
AccDataType
;
static
constexpr
const
char
*
name
=
"fp8"
;
using
CDataType
=
Types
::
CDataType
;
};
template
<
>
struct
DataTypeTraits
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
auto
create_args
(
int
argc
,
char
*
argv
[])
auto
create_args
(
int
argc
,
char
*
argv
[])
{
{
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
8ce41034
...
@@ -9,6 +9,7 @@ static constexpr inline auto is_row_major(Layout layout_)
...
@@ -9,6 +9,7 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>>
{};
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>>
{};
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
auto
calculate_rtol_atol
(
const
ck_tile
::
index_t
K
,
auto
calculate_rtol_atol
(
const
ck_tile
::
index_t
K
,
const
ck_tile
::
index_t
kbatch
,
const
ck_tile
::
index_t
kbatch
,
const
float
max_accumulated_value
)
const
float
max_accumulated_value
)
...
@@ -29,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
...
@@ -29,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
...
@@ -55,7 +57,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
...
@@ -55,7 +57,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args
.
stride_B
=
stride_B
;
args
.
stride_B
=
stride_B
;
args
.
stride_C
=
stride_C
;
args
.
stride_C
=
stride_C
;
float
ave_time
=
gemm_calc
<
ALayout
,
BLayout
,
CLayout
>
(
float
ave_time
=
gemm_calc
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
...
@@ -66,13 +69,19 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
...
@@ -66,13 +69,19 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std
::
cout
<<
"Run Gemm kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
std
::
cout
<<
"Run Gemm kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" A_Layout ="
<<
ALayout
::
name
<<
" B_Layout ="
<<
BLayout
::
name
<<
" C_Layout ="
<<
CLayout
::
name
<<
" A Type = "
<<
DataTypeTraits
<
ADataType
>::
name
<<
" B Type = "
<<
DataTypeTraits
<
BDataType
>::
name
<<
" C Type = "
<<
DataTypeTraits
<
CDataType
>::
name
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
<<
std
::
endl
;
return
ave_time
;
return
ave_time
;
}
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
PrecType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
int
run_gemm_example_with_layouts
(
int
argc
,
int
run_gemm_example_with_layouts
(
int
argc
,
char
*
argv
[],
char
*
argv
[],
const
ALayout
a_layout
=
ALayout
{},
const
ALayout
a_layout
=
ALayout
{},
...
@@ -83,6 +92,11 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -83,6 +92,11 @@ int run_gemm_example_with_layouts(int argc,
if
(
!
result
)
if
(
!
result
)
return
-
1
;
return
-
1
;
using
ADataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
ADataType
;
using
BDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
BDataType
;
using
CDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
CDataType
;
using
AccDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
AccDataType
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
...
@@ -119,7 +133,8 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -119,7 +133,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
invoke_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
c_m_n_dev_buf
,
M
,
M
,
...
@@ -145,7 +160,8 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -145,7 +160,8 @@ int run_gemm_example_with_layouts(int argc,
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
...
@@ -202,7 +218,8 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -202,7 +218,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
const
float
max_accumulated_value
=
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
...
...
example/ck_tile/03_gemm/script/benchmark_basic.sh
View file @
8ce41034
...
@@ -2,7 +2,8 @@
...
@@ -2,7 +2,8 @@
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
VALID
=
1
VALID
=
1
for
b_matrix_layout
in
"R"
"C"
;
do
for
b_matrix_layout
in
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
...
...
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
0 → 100644
View file @
8ce41034
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
0 → 100644
View file @
8ce41034
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
0 → 100644
View file @
8ce41034
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp8
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
View file @
8ce41034
...
@@ -2,10 +2,10 @@
...
@@ -2,10 +2,10 @@
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
VALID
=
1
for
b_matrix_layout
in
"R"
"C"
;
do
for
b_matrix_layout
in
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
$EXE
-prec
=
fp16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
...
...
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
0 → 100644
View file @
8ce41034
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
bf16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
0 → 100644
View file @
8ce41034
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
bf8
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
0 → 100644
View file @
8ce41034
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp8
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
example/ck_tile/03_gemm/script/smoke_test_basic.sh
View file @
8ce41034
...
@@ -7,22 +7,20 @@ export CK_REPEAT=1
...
@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
run_fp16_tests
()
{
run_tests
()
{
for
batch
in
1 2
;
do
for
m
in
128 1024
;
do
for
m
in
128 1024
;
do
for
n
in
128 2048
;
do
for
n
in
128 2048
;
do
for
k
in
64 128
;
do
for
k
in
32 64
;
do
$EXE
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-prec
=
$1
$COMMON_ARGS
$EXE
-b
=
$batch
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-e
=
1e-5
-prec
=
fp16
$COMMON_ARGS
if
[
$?
-eq
0
]
;
then
if
[
$?
-eq
0
]
;
then
echo
"Success: Test with m=
$m
, n=
$n
, k=
$k
executed successfully."
echo
"Success: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
executed successfully."
else
else
echo
"Error: Test with m=
$m
, n=
$n
, k=
$k
failed to execute properly."
echo
"Error: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# Optionally, exit or break if you need to halt further execution
# exit 1
# exit 1
fi
fi
done
done
done
done
done
done
done
...
@@ -30,6 +28,9 @@ run_fp16_tests() {
...
@@ -30,6 +28,9 @@ run_fp16_tests() {
set
-x
set
-x
run_fp16_tests
run_tests
"fp16"
run_tests
"bf16"
run_tests
"fp8"
run_tests
"bf8"
set
+x
set
+x
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
View file @
8ce41034
...
@@ -7,22 +7,20 @@ export CK_REPEAT=1
...
@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
run_fp16_tests
()
{
run_tests
()
{
for
batch
in
1 2
;
do
for
m
in
512 1024
;
do
for
m
in
128 1024
;
do
for
n
in
512 2048
;
do
for
n
in
128 2048
;
do
for
k
in
512 1024
;
do
for
k
in
32 64
;
do
$EXE
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-prec
=
$1
$COMMON_ARGS
$EXE
-b
=
$batch
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-e
=
1e-5
-prec
=
fp16
$COMMON_ARGS
if
[
$?
-eq
0
]
;
then
if
[
$?
-eq
0
]
;
then
echo
"Success: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
executed successfully."
echo
"Success: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
executed successfully."
else
else
echo
"Error: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
failed to execute properly."
echo
"Error: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# Optionally, exit or break if you need to halt further execution
# exit 1
# exit 1
fi
fi
done
done
done
done
done
done
done
...
@@ -30,6 +28,9 @@ run_fp16_tests() {
...
@@ -30,6 +28,9 @@ run_fp16_tests() {
set
-x
set
-x
run_fp16_tests
run_tests
"fp16"
run_tests
"bf16"
run_tests
"fp8"
run_tests
"bf8"
set
+x
set
+x
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
8ce41034
...
@@ -12,7 +12,13 @@
...
@@ -12,7 +12,13 @@
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#include "gemm_basic.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
...
@@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
// Compute friendly for Intrawave scheduler
// Compute friendly for Intrawave scheduler
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Tile
=
64
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
...
@@ -243,24 +249,101 @@ int run_gemm_example(int argc, char* argv[])
...
@@ -243,24 +249,101 @@ int run_gemm_example(int argc, char* argv[])
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"C"
)
else
if
(
a_layout
==
"C"
&&
b_layout
==
"C"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"R"
)
else
if
(
a_layout
==
"C"
&&
b_layout
==
"R"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
}
else
else
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
View file @
8ce41034
...
@@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if
(
stream_config
.
log_level_
>
0
)
if
(
stream_config
.
log_level_
>
0
)
{
{
arg
.
Print
();
arg
.
Print
();
GridwiseGemm
::
BlockwiseGemmPipe
::
HotLoopInstList
::
Print
();
}
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
))
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
))
...
@@ -745,7 +746,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -745,7 +746,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
<<
"BlkGemmPipelineVersion: "
<<
"BlkGemmPipelineVersion: "
<<
BlkGemmPipelineVersionToString
[
BlkGemmPipelineVer
]
<<
", "
<<
BlkGemmPipelineVersionToString
[
BlkGemmPipelineVer
]
<<
", "
<<
"BlkGemmPipelinePrefetchStages: "
<<
"BlkGemmPipelinePrefetchStages: "
<<
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
;
<<
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
<<
", "
<<
"Kpack: "
<<
GridwiseGemm
::
BlockwiseGemmPipe
::
AMmaKStride
;
// clang-format on
// clang-format on
return
str
.
str
();
return
str
.
str
();
...
...
include/ck/utility/blkgemmpipe_scheduler.hpp
View file @
8ce41034
...
@@ -103,14 +103,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
...
@@ -103,14 +103,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
KPerXDL
);
KPerXDL
);
printf
(
" A/B buffer load inst: %d, %d
\n
A/B LDS write inst: %d, %d
\n
A/B LDS read inst: "
printf
(
" A/B buffer load inst: %d, %d
\n
A/B LDS write inst: %d, %d
\n
A/B LDS read inst: "
"%d, %d
\n
C MFMA inst: %d
\n
"
,
"%d, %d
\n
C MFMA inst: %d
\n
"
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
"%d/ %d
\n
"
,
A_Buffer_Load_Inst_Num
,
A_Buffer_Load_Inst_Num
,
B_Buffer_Load_Inst_Num
,
B_Buffer_Load_Inst_Num
,
A_LDS_Write_Inst_Num
,
A_LDS_Write_Inst_Num
,
B_LDS_Write_Inst_Num
,
B_LDS_Write_Inst_Num
,
A_LDS_Read_Inst_Num
,
A_LDS_Read_Inst_Num
,
B_LDS_Read_Inst_Num
,
B_LDS_Read_Inst_Num
,
C_MFMA_Inst_Num
);
C_MFMA_Inst_Num
,
A_LDS_Read_Width
,
B_LDS_Read_Width
,
ALDSWriteWidth
,
BLDSWriteWidth
,
ABufferLoadWidth
,
BBufferLoadWidth
);
}
}
};
};
...
...
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
View file @
8ce41034
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
...
@@ -8,16 +8,75 @@
...
@@ -8,16 +8,75 @@
namespace
ck_tile
{
namespace
ck_tile
{
CK_TILE_HOST_DEVICE
bf16_t
add_bf16_t
(
const
bf16_t
&
a
,
const
bf16_t
&
b
)
template
<
typename
T
,
typename
ComputeType
>
CK_TILE_HOST_DEVICE
T
add
(
const
T
&
a
,
const
T
&
b
)
{
{
return
type_convert
<
bf16_t
>
(
type_convert
<
float
>
(
a
)
+
type_convert
<
float
>
(
b
));
return
type_convert
<
T
>
(
type_convert
<
ComputeType
>
(
a
)
+
type_convert
<
ComputeType
>
(
b
));
}
}
CK_TILE_HOST_DEVICE
bf16x2_t
add_bf16x2_t
(
const
bf16x2_t
&
a
,
const
bf16x2_t
&
b
)
CK_TILE_HOST_DEVICE
bf16x2_t
add_bf16x2_t
(
const
bf16x2_t
&
a
,
const
bf16x2_t
&
b
)
{
{
bf16x2_t
rtn
;
bf16x2_t
rtn
;
rtn
[
0
]
=
add_bf16_t
(
a
[
0
],
b
[
0
]);
rtn
[
0
]
=
add
<
bf16_t
,
float
>
(
a
[
0
],
b
[
0
]);
rtn
[
1
]
=
add_bf16_t
(
a
[
1
],
b
[
1
]);
rtn
[
1
]
=
add
<
bf16_t
,
float
>
(
a
[
1
],
b
[
1
]);
return
rtn
;
}
// Lane-wise sum of two 4-element bf16 vectors.
// Every lane is produced by add<bf16_t, float>, the same helper call the
// 2-element variant uses for its lanes.
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
    bf16x4_t sum;
    // Fixed trip count of 4: one helper call per vector lane.
    for(int lane = 0; lane < 4; ++lane)
    {
        sum[lane] = add<bf16_t, float>(a[lane], b[lane]);
    }
    return sum;
}
// Lane-wise sum of two 4-element fp8 vectors.
// Every lane is produced by add<fp8_t, float>.
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
    fp8x4_t sum;
    // Fixed trip count of 4: one helper call per vector lane.
    for(int lane = 0; lane < 4; ++lane)
    {
        sum[lane] = add<fp8_t, float>(a[lane], b[lane]);
    }
    return sum;
}
// Lane-wise sum of two 8-element fp8 vectors.
// Every lane is produced by add<fp8_t, float>.
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
    fp8x8_t sum;
    // Fixed trip count of 8: one helper call per vector lane.
    for(int lane = 0; lane < 8; ++lane)
    {
        sum[lane] = add<fp8_t, float>(a[lane], b[lane]);
    }
    return sum;
}
// Lane-wise sum of two 4-element bf8 vectors.
// Every lane is produced by add<bf8_t, float>.
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
    bf8x4_t sum;
    // Fixed trip count of 4: one helper call per vector lane.
    for(int lane = 0; lane < 4; ++lane)
    {
        sum[lane] = add<bf8_t, float>(a[lane], b[lane]);
    }
    return sum;
}
CK_TILE_HOST_DEVICE
bf8x8_t
add_bf8x8_t
(
const
bf8x8_t
&
a
,
const
bf8x8_t
&
b
)
{
bf8x8_t
rtn
;
rtn
[
0
]
=
add
<
bf8_t
,
float
>
(
a
[
0
],
b
[
0
]);
rtn
[
1
]
=
add
<
bf8_t
,
float
>
(
a
[
1
],
b
[
1
]);
rtn
[
2
]
=
add
<
bf8_t
,
float
>
(
a
[
2
],
b
[
2
]);
rtn
[
3
]
=
add
<
bf8_t
,
float
>
(
a
[
3
],
b
[
3
]);
rtn
[
4
]
=
add
<
bf8_t
,
float
>
(
a
[
4
],
b
[
4
]);
rtn
[
5
]
=
add
<
bf8_t
,
float
>
(
a
[
5
],
b
[
5
]);
rtn
[
6
]
=
add
<
bf8_t
,
float
>
(
a
[
6
],
b
[
6
]);
rtn
[
7
]
=
add
<
bf8_t
,
float
>
(
a
[
7
],
b
[
7
]);
return
rtn
;
return
rtn
;
}
}
...
@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
...
@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
}
while
(
cur_v
.
u32
!=
old_v
);
}
while
(
cur_v
.
u32
!=
old_v
);
}
}
// Atomic add specialization for bf16x4_t, built on a 64-bit compare-and-swap.
//
// p_dst : destination vector, also accessed through a uint64_t* view.
// x     : vector added lane-wise (via add_bf16x4_t) to the value at p_dst.
template <>
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
{
    // Two views of the destination pointer: as a vector and as raw 64 bits.
    union PtrView
    {
        bf16x4_t* as_vec;
        uint64_t* as_bits;
    };

    // Two views of one value: as a vector and as raw 64 bits.
    union ValView
    {
        bf16x4_t as_vec;
        uint64_t as_bits;
    };

    PtrView dst;
    dst.as_vec = p_dst;

    // Plain (non-atomic) initial load; the CAS below validates it.
    ValView observed;
    observed.as_bits = *dst.as_bits;

    ValView desired;
    for(;;)
    {
        const uint64_t expected = observed.as_bits;

        // Lane-wise bf16 sum of the observed value and x.
        desired.as_vec = add_bf16x4_t(observed.as_vec, x);

        // Keep whatever atomicCAS returns; if it is not what we expected,
        // go around again starting from that returned value.
        observed.as_bits = atomicCAS(dst.as_bits, expected, desired.as_bits);
        if(observed.as_bits == expected)
        {
            break;
        }
    }
}
// Atomic add specialization for fp8x4_t, built on a 32-bit compare-and-swap.
//
// p_dst : destination vector, also accessed through a uint32_t* view.
// x     : vector added lane-wise (via add_fp8x4_t) to the value at p_dst.
template <>
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
{
    // Two views of the destination pointer: as a vector and as raw 32 bits.
    union PtrView
    {
        fp8x4_t* as_vec;
        uint32_t* as_bits;
    };

    // Two views of one value: as a vector and as raw 32 bits.
    union ValView
    {
        fp8x4_t as_vec;
        uint32_t as_bits;
    };

    PtrView dst;
    dst.as_vec = p_dst;

    // Plain (non-atomic) initial load; the CAS below validates it.
    ValView observed;
    observed.as_bits = *dst.as_bits;

    ValView desired;
    for(;;)
    {
        const uint32_t expected = observed.as_bits;

        // Lane-wise fp8 sum of the observed value and x.
        desired.as_vec = add_fp8x4_t(observed.as_vec, x);

        // Keep whatever atomicCAS returns; if it is not what we expected,
        // go around again starting from that returned value.
        observed.as_bits = atomicCAS(dst.as_bits, expected, desired.as_bits);
        if(observed.as_bits == expected)
        {
            break;
        }
    }
}
// Atomic add specialization for bf8x4_t, built on a 32-bit compare-and-swap.
//
// p_dst : destination vector, also accessed through a uint32_t* view.
// x     : vector added lane-wise (via add_bf8x4_t) to the value at p_dst.
template <>
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
{
    // Two views of the destination pointer: as a vector and as raw 32 bits.
    union PtrView
    {
        bf8x4_t* as_vec;
        uint32_t* as_bits;
    };

    // Two views of one value: as a vector and as raw 32 bits.
    union ValView
    {
        bf8x4_t as_vec;
        uint32_t as_bits;
    };

    PtrView dst;
    dst.as_vec = p_dst;

    // Plain (non-atomic) initial load; the CAS below validates it.
    ValView observed;
    observed.as_bits = *dst.as_bits;

    ValView desired;
    for(;;)
    {
        const uint32_t expected = observed.as_bits;

        // Lane-wise bf8 sum of the observed value and x.
        desired.as_vec = add_bf8x4_t(observed.as_vec, x);

        // Keep whatever atomicCAS returns; if it is not what we expected,
        // go around again starting from that returned value.
        observed.as_bits = atomicCAS(dst.as_bits, expected, desired.as_bits);
        if(observed.as_bits == expected)
        {
            break;
        }
    }
}
//
// Atomic add for fp8x8_t
//
// Atomic add specialization for fp8x8_t, built on a 64-bit compare-and-swap.
//
// p_dst : destination vector, also accessed through a uint64_t* view.
// x     : vector added lane-wise (via add_fp8x8_t) to the value at p_dst.
template <>
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
{
    // Two views of the destination pointer: as a vector and as raw 64 bits.
    union PtrView
    {
        fp8x8_t* as_vec;
        uint64_t* as_bits;
    };

    // Two views of one value: as a vector and as raw 64 bits.
    union ValView
    {
        fp8x8_t as_vec;
        uint64_t as_bits;
    };

    PtrView dst;
    dst.as_vec = p_dst;

    // Plain (non-atomic) initial load; the CAS below validates it.
    ValView observed;
    observed.as_bits = *dst.as_bits;

    ValView desired;
    for(;;)
    {
        const uint64_t expected = observed.as_bits;

        // Lane-wise fp8 sum of the observed value and x.
        desired.as_vec = add_fp8x8_t(observed.as_vec, x);

        // Keep whatever atomicCAS returns; if it is not what we expected,
        // go around again starting from that returned value.
        observed.as_bits = atomicCAS(dst.as_bits, expected, desired.as_bits);
        if(observed.as_bits == expected)
        {
            break;
        }
    }
}
//
// Atomic add for bf8x8_t
//
// Atomic add specialization for bf8x8_t, built on a 64-bit compare-and-swap.
//
// p_dst : destination vector, also accessed through a uint64_t* view.
// x     : vector added lane-wise (via add_bf8x8_t) to the value at p_dst.
template <>
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
{
    // Two views of the destination pointer: as a vector and as raw 64 bits.
    union PtrView
    {
        bf8x8_t* as_vec;
        uint64_t* as_bits;
    };

    // Two views of one value: as a vector and as raw 64 bits.
    union ValView
    {
        bf8x8_t as_vec;
        uint64_t as_bits;
    };

    PtrView dst;
    dst.as_vec = p_dst;

    // Plain (non-atomic) initial load; the CAS below validates it.
    ValView observed;
    observed.as_bits = *dst.as_bits;

    ValView desired;
    for(;;)
    {
        const uint64_t expected = observed.as_bits;

        // Lane-wise bf8 sum of the observed value and x.
        desired.as_vec = add_bf8x8_t(observed.as_vec, x);

        // Keep whatever atomicCAS returns; if it is not what we expected,
        // go around again starting from that returned value.
        observed.as_bits = atomicCAS(dst.as_bits, expected, desired.as_bits);
        if(observed.as_bits == expected)
        {
            break;
        }
    }
}
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
CK_TILE_DEVICE
void
atomic_add_g
(
T
*
p_dst
,
const
thread_buffer
<
T
,
N
>&
x
)
CK_TILE_DEVICE
void
atomic_add_g
(
T
*
p_dst
,
const
thread_buffer
<
T
,
N
>&
x
)
{
{
...
@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
...
@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
(
std
::
is_same
<
T
,
uint32_t
>::
value
&&
(
N
==
1
))
||
(
std
::
is_same
<
T
,
uint32_t
>::
value
&&
(
N
==
1
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
std
::
is_same
<
T
,
bf16_t
>::
value
&&
(
N
==
2
||
N
==
4
)),
(
std
::
is_same
<
T
,
bf16_t
>::
value
&&
(
N
==
2
||
N
==
4
||
N
==
8
))
||
"wrong! not implemented"
);
(
std
::
is_same
<
T
,
fp8_t
>::
value
&&
(
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
4
||
N
==
8
||
N
==
16
)),
"The granularity of the thread buffer is unsupported on the hardware!"
);
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
...
@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
...
@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
}
}
else
if
constexpr
(
N
==
4
)
else
if
constexpr
(
N
==
4
)
{
{
atomic_add
(
c_style_pointer_cast
<
bf16x2_t
*>
(
p_dst
),
x
.
template
get_as
<
bf16x2_t
>()[
I0
]);
atomic_add
(
c_style_pointer_cast
<
bf16x4_t
*>
(
p_dst
),
x
.
template
get_as
<
bf16x4_t
>()[
I0
]);
atomic_add
(
c_style_pointer_cast
<
bf16x2_t
*>
(
p_dst
)
+
1
,
}
x
.
template
get_as
<
bf16x2_t
>()[
I1
]);
else
if
constexpr
(
N
==
8
)
{
atomic_add
(
c_style_pointer_cast
<
bf16x4_t
*>
(
p_dst
),
x
.
template
get_as
<
bf16x4_t
>()[
I0
]);
atomic_add
(
c_style_pointer_cast
<
bf16x4_t
*>
(
p_dst
)
+
1
,
x
.
template
get_as
<
bf16x4_t
>()[
I1
]);
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
fp8_t
>::
value
)
{
if
constexpr
(
N
==
4
)
{
atomic_add
(
c_style_pointer_cast
<
fp8x4_t
*>
(
p_dst
),
x
.
template
get_as
<
fp8x4_t
>()[
I0
]);
}
if
constexpr
(
N
==
8
)
{
atomic_add
(
c_style_pointer_cast
<
fp8x8_t
*>
(
p_dst
),
x
.
template
get_as
<
fp8x8_t
>()[
I0
]);
}
if
constexpr
(
N
==
16
)
{
atomic_add
(
c_style_pointer_cast
<
fp8x8_t
*>
(
p_dst
),
x
.
template
get_as
<
fp8x8_t
>()[
I0
]);
atomic_add
(
c_style_pointer_cast
<
fp8x8_t
*>
(
p_dst
)
+
1
,
x
.
template
get_as
<
fp8x8_t
>()[
I1
]);
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
bf8_t
>::
value
)
{
if
constexpr
(
N
==
4
)
{
atomic_add
(
c_style_pointer_cast
<
bf8x4_t
*>
(
p_dst
),
x
.
template
get_as
<
bf8x4_t
>()[
I0
]);
}
if
constexpr
(
N
==
8
)
{
atomic_add
(
c_style_pointer_cast
<
bf8x8_t
*>
(
p_dst
),
x
.
template
get_as
<
bf8x8_t
>()[
I0
]);
}
if
constexpr
(
N
==
16
)
{
atomic_add
(
c_style_pointer_cast
<
bf8x8_t
*>
(
p_dst
),
x
.
template
get_as
<
bf8x8_t
>()[
I0
]);
atomic_add
(
c_style_pointer_cast
<
bf8x8_t
*>
(
p_dst
)
+
1
,
x
.
template
get_as
<
bf8x8_t
>()[
I1
]);
}
}
}
}
}
}
...
...
include/ck_tile/host.hpp
View file @
8ce41034
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
...
@@ -34,4 +35,3 @@
...
@@ -34,4 +35,3 @@
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
#include "ck_tile/host/timer.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment