Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7fb0b322
Commit
7fb0b322
authored
Oct 21, 2024
by
chenjun
Browse files
add int8 gemm multiply multiply a8w8
parent
95e722a3
Changes
17
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
694 additions
and
177 deletions
+694
-177
include/ck/host_utility/flush_cache.hpp
include/ck/host_utility/flush_cache.hpp
+35
-17
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...k/tensor_operation/gpu/element/element_wise_operation.hpp
+18
-0
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+1
-1
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
.../tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
+105
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt
...ration_instance/gpu/gemm_multiply_multiply/CMakeLists.txt
+10
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
...device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
+99
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
...ultiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
+32
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
...ltiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
+32
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
...tiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
...iply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
...tiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
...iply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
+33
-0
profiler/fp8_gmm_profiler.sh
profiler/fp8_gmm_profiler.sh
+31
-0
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
.../include/profiler/profile_gemm_multiply_multiply_impl.hpp
+15
-14
profiler/int8_gmm_profiler.sh
profiler/int8_gmm_profiler.sh
+31
-0
profiler/src/CMakeLists.txt
profiler/src/CMakeLists.txt
+144
-144
profiler/src/profile_gemm_multiply_multiply.cpp
profiler/src/profile_gemm_multiply_multiply.cpp
+9
-1
No files found.
include/ck/host_utility/flush_cache.hpp
View file @
7fb0b322
...
@@ -237,7 +237,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -237,7 +237,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
Args
...
args
)
Args
...
args
)
{
{
#if CK_TIME_KERNEL
#if CK_TIME_KERNEL
#define MEDIAN
1
#define MEDIAN
0
if
(
stream_config
.
time_kernel_
)
if
(
stream_config
.
time_kernel_
)
{
{
if
(
ck
::
EnvIsEnabled
(
CK_ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_ENV
(
CK_LOGGING
)))
...
@@ -275,6 +275,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -275,6 +275,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
#else
#else
float
total_time
=
0
;
float
total_time
=
0
;
#endif
#endif
hipEvent_t
start
,
stop
;
hip_check_error
(
hipEventCreate
(
&
start
));
hip_check_error
(
hipEventCreate
(
&
stop
));
hip_check_error
(
hipDeviceSynchronize
());
hip_check_error
(
hipEventRecord
(
start
,
stream_config
.
stream_id_
));
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
if
constexpr
(
!
TimePreprocess
)
if
constexpr
(
!
TimePreprocess
)
...
@@ -282,13 +290,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -282,13 +290,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
preprocess
();
preprocess
();
}
}
hipEvent_t
start
,
stop
;
//
hipEvent_t start, stop;
hip_check_error
(
hipEventCreate
(
&
start
));
//
hip_check_error(hipEventCreate(&start));
hip_check_error
(
hipEventCreate
(
&
stop
));
//
hip_check_error(hipEventCreate(&stop));
hip_check_error
(
hipDeviceSynchronize
());
//
hip_check_error(hipDeviceSynchronize());
hip_check_error
(
hipEventRecord
(
start
,
stream_config
.
stream_id_
));
//
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
// calculate preprocess time
// calculate preprocess time
if
constexpr
(
TimePreprocess
)
if
constexpr
(
TimePreprocess
)
{
{
...
@@ -299,25 +307,34 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -299,25 +307,34 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error
(
hipGetLastError
());
hip_check_error
(
hipGetLastError
());
// end real kernel
// end real kernel
hip_check_error
(
hipEventRecord
(
stop
,
stream_config
.
stream_id_
));
//
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error
(
hipEventSynchronize
(
stop
));
//
hip_check_error(hipEventSynchronize(stop));
float
cur_time
=
0
;
//
float cur_time = 0;
hip_check_error
(
hipEventElapsedTime
(
&
cur_time
,
start
,
stop
));
//
hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN
//
#if MEDIAN
times
.
insert
(
cur_time
);
//
times.insert(cur_time);
#else
//
#else
total_time
+=
cur_time
;
//
total_time += cur_time;
#endif
//
#endif
if
(
ck
::
EnvIsEnabled
(
CK_ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"i: "
<<
i
<<
" cur_time: "
<<
cur_time
<<
std
::
endl
;
//
std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
printf
(
"gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p
\n
"
,
printf
(
"gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p
\n
"
,
static_cast
<
const
void
*>
(
gemm_args
.
p_a_grid
),
static_cast
<
const
void
*>
(
gemm_args
.
p_a_grid
),
static_cast
<
const
void
*>
(
gemm_args
.
p_b_grid
));
static_cast
<
const
void
*>
(
gemm_args
.
p_b_grid
));
}
}
}
}
hip_check_error
(
hipEventRecord
(
stop
,
stream_config
.
stream_id_
));
hip_check_error
(
hipEventSynchronize
(
stop
));
float
cur_time
=
0
;
hip_check_error
(
hipEventElapsedTime
(
&
cur_time
,
start
,
stop
));
#if MEDIAN
times
.
insert
(
cur_time
);
#else
total_time
+=
cur_time
;
#endif
#if MEDIAN
#if MEDIAN
auto
mid
=
times
.
begin
();
auto
mid
=
times
.
begin
();
...
@@ -333,7 +350,8 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -333,7 +350,8 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
return
(
*
mid
+
*
mid_next
)
/
2
;
return
(
*
mid
+
*
mid_next
)
/
2
;
}
}
#else
#else
return
total_time
/
nrepeat
;
// return total_time / nrepeat;
return
(
total_time
-
0.01
*
nrepeat
)
/
nrepeat
;
#endif
#endif
}
}
else
else
...
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
7fb0b322
...
@@ -272,6 +272,24 @@ struct MultiplyMultiply
...
@@ -272,6 +272,24 @@ struct MultiplyMultiply
e
=
ck
::
type_convert
<
ck
::
bhalf_t
>
(
x0_f
);
e
=
ck
::
type_convert
<
ck
::
bhalf_t
>
(
x0_f
);
}
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
half_t
,
int
,
ck
::
half_t
,
ck
::
half_t
>
(
ck
::
half_t
&
e
,
const
int
&
c
,
const
ck
::
half_t
&
d0
,
const
ck
::
half_t
&
d1
)
const
{
const
float
x0_f
=
ck
::
type_convert
<
float
>
(
c
)
*
ck
::
type_convert
<
float
>
(
d0
)
*
ck
::
type_convert
<
float
>
(
d1
);
e
=
ck
::
type_convert
<
ck
::
half_t
>
(
x0_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
bhalf_t
,
int
,
float
,
float
>
(
ck
::
bhalf_t
&
e
,
const
int
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
{
const
float
x0_f
=
ck
::
type_convert
<
float
>
(
c
)
*
ck
::
type_convert
<
float
>
(
d0
)
*
ck
::
type_convert
<
float
>
(
d1
);
e
=
ck
::
type_convert
<
ck
::
bhalf_t
>
(
x0_f
);
}
};
};
struct
MultiplyAddFastGelu
struct
MultiplyAddFastGelu
...
...
include/ck/utility/amd_xdlops.hpp
View file @
7fb0b322
...
@@ -327,7 +327,7 @@ struct intrin_mfma_i32_16x16x32i8<16, 16>
...
@@ -327,7 +327,7 @@ struct intrin_mfma_i32_16x16x32i8<16, 16>
__device__
static
void
Run
(
const
int8x8_t
&
reg_a
,
const
int8x8_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
int8x8_t
&
reg_a
,
const
int8x8_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
template
AsType
<
int32x4_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
int32x4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_i32_16x16x32i8
(
bit_cast
<
int64_t
>
(
reg_a
),
__builtin_amdgcn_mfma_i32_16x16x32
_
i8
(
bit_cast
<
int64_t
>
(
reg_a
),
bit_cast
<
int64_t
>
(
reg_b
),
bit_cast
<
int64_t
>
(
reg_b
),
reg_c
.
template
AsType
<
int32x4_t
>()[
Number
<
0
>
{}],
reg_c
.
template
AsType
<
int32x4_t
>()[
Number
<
0
>
{}],
0
,
0
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
View file @
7fb0b322
...
@@ -96,6 +96,87 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
...
@@ -96,6 +96,87 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
#endif
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
#endif
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
...
@@ -155,6 +236,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -155,6 +236,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
}
}
#endif
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt
View file @
7fb0b322
...
@@ -8,9 +8,19 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
...
@@ -8,9 +8,19 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
)
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
add_instance_library
(
device_gemm_multiply_multiply_instance
${
GEMM_MULTIPLY_MULTIPLY_INSTANCES
}
)
add_instance_library
(
device_gemm_multiply_multiply_instance
${
GEMM_MULTIPLY_MULTIPLY_INSTANCES
}
)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
0 → 100644
View file @
7fb0b322
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
0 → 100644
View file @
7fb0b322
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances
<
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
0 → 100644
View file @
7fb0b322
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances
<
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
0 → 100644
View file @
7fb0b322
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Intrawave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
0 → 100644
View file @
7fb0b322
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Intrawave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
0 → 100644
View file @
7fb0b322
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Interwave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
0 → 100644
View file @
7fb0b322
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Interwave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/fp8_gmm_profiler.sh
0 → 100644
View file @
7fb0b322
EXE
=
"
$(
find
.
-name
ckProfiler
-type
f |
head
-n
1
)
"
op
=
"gemm_multiply_multiply"
loopFunc
()
{
N
=
$1
K
=
$2
$EXE
$op
7 1 0 2 0 1 1
$N
$K
-1
-1
0 0
-1
1 40 500 4096
for
((
M
=
32
;
M<
=
20480
;
M
*
=
2
))
do
# echo "M = $M, N = $N, K = $K"
$EXE
$op
7 1 0 2 0 1
$M
$N
$K
-1
-1
0 0
-1
1 40 500 4096
done
$EXE
$op
7 1 0 2 0 1 20480
$N
$K
-1
-1
0 0
-1
1 40 500 4096
}
N
=
4608
K
=
3584
loopFunc
$N
$K
N
=
3584
K
=
3584
loopFunc
$N
$K
N
=
3584
K
=
20480
loopFunc
$N
$K
N
=
40960
K
=
3584
loopFunc
$N
$K
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
View file @
7fb0b322
...
@@ -84,12 +84,12 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -84,12 +84,12 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
std
::
min
(
n_iter
,
std
::
min
(
n_iter
,
static_cast
<
int
>
(
std
::
ceil
(
static_cast
<
double
>
(
rotating
)
/
total_gemm_needed
))));
static_cast
<
int
>
(
std
::
ceil
(
static_cast
<
double
>
(
rotating
)
/
total_gemm_needed
))));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
//
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
//
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std
::
cout
<<
"d0_m_n: "
<<
d0_m_n
.
mDesc
<<
std
::
endl
;
//
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std
::
cout
<<
"d1_m_n: "
<<
d1_m_n
.
mDesc
<<
std
::
endl
;
//
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std
::
cout
<<
"e_m_n: "
<<
e_m_n_device_result
.
mDesc
<<
std
::
endl
;
//
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
std
::
cout
<<
"rotating count: "
<<
rotating_count
<<
std
::
endl
;
//
std::cout << "rotating count: " << rotating_count << std::endl;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
@@ -146,7 +146,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -146,7 +146,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
//
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
// Run reference GEMM
if
(
do_verification
)
if
(
do_verification
)
...
@@ -267,14 +267,15 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -267,14 +267,15 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
//
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
", KBatch "
//
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<<
kbatch_curr
<<
std
::
endl
;
//
<< kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_INT8
// set softer tolerances for fp8
// set softer tolerances for fp8
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
||
is_same_v
<
BDataType
,
f8_t
>
||
if
constexpr
((
is_same_v
<
ADataType
,
f8_t
>
||
is_same_v
<
BDataType
,
f8_t
>
||
is_same_v
<
EDataType
,
f8_t
>
)
is_same_v
<
EDataType
,
f8_t
>
)
||
(
is_same_v
<
ADataType
,
int8_t
>
||
is_same_v
<
BDataType
,
int8_t
>
||
is_same_v
<
EDataType
,
int8_t
>
))
{
{
std
::
string
msg
=
"Error: Incorrect results!"
;
std
::
string
msg
=
"Error: Incorrect results!"
;
double
rtol
=
1e-1
;
double
rtol
=
1e-1
;
...
@@ -286,7 +287,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
...
@@ -286,7 +287,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
{
{
#endif
#endif
pass
=
pass
&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
pass
=
pass
&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
#if defined CK_ENABLE_FP8
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_INT8
}
}
#endif
#endif
...
...
profiler/int8_gmm_profiler.sh
0 → 100644
View file @
7fb0b322
EXE
=
"
$(
find
.
-name
ckProfiler
-type
f |
head
-n
1
)
"
op
=
"gemm_multiply_multiply"
loopFunc
()
{
N
=
$1
K
=
$2
$EXE
$op
8 1 0 2 0 1 1
$N
$K
-1
-1
0 0
-1
1 40 500 4096
for
((
M
=
32
;
M<
=
20480
;
M
*
=
2
))
do
# echo "M = $M, N = $N, K = $K"
$EXE
$op
8 1 0 2 0 1
$M
$N
$K
-1
-1
0 0
-1
1 40 500 4096
done
$EXE
$op
8 1 0 2 0 1 20480
$N
$K
-1
-1
0 0
-1
1 40 500 4096
}
# N=4608
# K=3584
# loopFunc $N $K
N
=
3584
K
=
3584
loopFunc
$N
$K
N
=
3584
K
=
20480
loopFunc
$N
$K
N
=
40960
K
=
3584
loopFunc
$N
$K
profiler/src/CMakeLists.txt
View file @
7fb0b322
This diff is collapsed.
Click to expand it.
profiler/src/profile_gemm_multiply_multiply.cpp
View file @
7fb0b322
...
@@ -27,6 +27,7 @@ enum struct GemmDataType
...
@@ -27,6 +27,7 @@ enum struct GemmDataType
F16_F8_F16
,
// 5
F16_F8_F16
,
// 5
F16_F16_F16_F8
,
// 6
F16_F16_F16_F8
,
// 6
F8_F8_BF16
,
// 7
F8_F8_BF16
,
// 7
INT8_INT8_BF16
,
// 8
};
};
#define OP_NAME "gemm_multiply_multiply"
#define OP_NAME "gemm_multiply_multiply"
...
@@ -39,7 +40,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
...
@@ -39,7 +40,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"f16->f8; 7: f8->bf16, "
"comp f8)
\n
"
);
"comp f8
; 8: int8->bf16
)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
...
@@ -89,6 +90,8 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
...
@@ -89,6 +90,8 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
using
F32
=
float
;
using
F32
=
float
;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
I8
=
int8_t
;
using
I32
=
int
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
@@ -162,6 +165,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
...
@@ -162,6 +165,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
return
profile
(
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
F32
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{},
Col
{},
Row
{});
F8
{},
F8
{},
F8
{},
F32
{},
F32
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{},
Col
{},
Row
{});
}
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
I8
{},
I8
{},
I8
{},
I32
{},
F32
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{},
Col
{},
Row
{});
}
else
else
{
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment