Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2e7c4ad7
Commit
2e7c4ad7
authored
Dec 03, 2023
by
Jing Zhang
Browse files
clean code
parent
c7d5c772
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
445 additions
and
190 deletions
+445
-190
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
+2
-1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+3
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+4
-4
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+262
-10
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+3
-2
library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
.../tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
+22
-20
profiler/src/CMakeLists.txt
profiler/src/CMakeLists.txt
+101
-101
profiler/src/profile_gemm_splitk.cpp
profiler/src/profile_gemm_splitk.cpp
+46
-49
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+2
-2
No files found.
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
View file @
2e7c4ad7
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
...
@@ -30,7 +31,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -30,7 +31,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
ADataType
=
F16
;
using
BDataType
=
F
16
;
using
BDataType
=
F
8
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
2e7c4ad7
...
@@ -197,7 +197,9 @@ struct PassThrough
...
@@ -197,7 +197,9 @@ struct PassThrough
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
f8_t
>
(
half_t
&
y
,
const
f8_t
&
x
)
const
__host__
__device__
void
operator
()
<
half_t
,
f8_t
>
(
half_t
&
y
,
const
f8_t
&
x
)
const
{
{
y
=
type_convert
<
half_t
>
(
x
);
const
uint16_t
tmp
=
bit_cast
<
uint8_t
>
(
x
);
y
=
bit_cast
<
half_t
>
(
tmp
);
// y = type_convert<half_t>(x);
}
}
template
<
>
template
<
>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
2e7c4ad7
...
@@ -401,7 +401,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -401,7 +401,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
().
GetElementSpaceSize
();
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
().
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
ComputeType
),
return
math
::
max
((
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
ComputeType
),
c_block_size
*
sizeof
(
Float
C
));
c_block_size
*
sizeof
(
Float
Acc
));
}
}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
...
@@ -891,7 +891,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -891,7 +891,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
Float
C
*>
(
p_shared_block
),
static_cast
<
Float
Acc
*>
(
p_shared_block
),
c_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
c_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
...
@@ -942,7 +942,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -942,7 +942,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// VGPR to LDS
// VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
Float
C
,
Float
Acc
,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -982,7 +982,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -982,7 +982,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
,
// BlockSliceLengths,
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
,
// BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Float
C
,
// typename SrcData,
Float
Acc
,
// typename SrcData,
FloatC
,
// typename DstData,
FloatC
,
// typename DstData,
decltype
(
c_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
2e7c4ad7
...
@@ -278,7 +278,7 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
...
@@ -278,7 +278,7 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
index_t
soffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fadd.f32"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fadd.f32"
);
// buffer atomic-
add
fp
32
// buffer atomic-
max
fp
64
__device__
double
__device__
double
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
double
vdata
,
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
double
vdata
,
int32x4_t
rsrc
,
// dst_wave_buffer_resource
int32x4_t
rsrc
,
// dst_wave_buffer_resource
...
@@ -420,10 +420,124 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
...
@@ -420,10 +420,124 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
T
,
N
>::
type
;
if
constexpr
(
is_same
<
T
,
float
>::
value
)
// fp32
auto
raw_data
=
amd_buffer_load_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
{
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
);
if
constexpr
(
N
==
1
)
return
bit_cast
<
r_t
>
(
raw_data
);
{
return
llvm_amdgcn_raw_buffer_load_fp32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
;
tmp
.
AsType
<
float4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
float4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
float8_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
// fp16
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_fp16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
// use fp32 load to mimic fp16 load
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
half8_t
>
(
tmp
);
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
// bf16
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
bhalf8_t
>
(
tmp
);
}
}
else
// other datatype
{
using
r_t
=
typename
vector_type
<
T
,
N
>::
type
;
auto
raw_data
=
amd_buffer_load_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
);
return
bit_cast
<
r_t
>
(
raw_data
);
}
}
}
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
...
@@ -542,12 +656,150 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
...
@@ -542,12 +656,150 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
int8_t
,
sizeof
(
T
)
*
N
>::
type
;
if
constexpr
(
is_same
<
T
,
float
>::
value
)
// fp32
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp32
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
// fp16
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
#if 0
vector_type<half_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
// bf16
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_i16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
bhalf_t
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
bhalf_t
),
static_cast
<
index_t
>
(
coherence
));
}
}
else
{
using
r_t
=
typename
vector_type
<
int8_t
,
sizeof
(
T
)
*
N
>::
type
;
amd_buffer_store_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
bit_cast
<
r_t
>
(
src_thread_data
),
amd_buffer_store_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
bit_cast
<
r_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
);
dst_wave_addr_offset
);
}
}
}
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
...
...
include/ck/utility/type_convert.hpp
View file @
2e7c4ad7
...
@@ -367,8 +367,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
...
@@ -367,8 +367,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
return
type_convert
<
half_t
>
(
type_convert
<
float
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
// constexpr bool negative_zero_nan = true;
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
// return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
return
static_cast
<
half_t
>
(
bit_cast
<
int8_t
>
(
x
));
#endif
#endif
}
}
...
...
library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
View file @
2e7c4ad7
set
(
GEMM_SPLITK_INSTANCES
)
set
(
GEMM_SPLITK_INSTANCES
)
list
(
APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
list
(
APPEND GEMM_SPLITK_INSTANCES
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
#device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
#device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
#device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp
#device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp
#device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp
#device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp
#device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
)
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
)
add_instance_library
(
device_gemm_splitk_instance
${
GEMM_SPLITK_INSTANCES
}
)
add_instance_library
(
device_gemm_splitk_instance
${
GEMM_SPLITK_INSTANCES
}
)
profiler/src/CMakeLists.txt
View file @
2e7c4ad7
# ckProfiler
# ckProfiler
set
(
PROFILER_SOURCES
set
(
PROFILER_SOURCES
profiler.cpp
profiler.cpp
profile_gemm.cpp
#
profile_gemm.cpp
profile_gemm_splitk.cpp
profile_gemm_splitk.cpp
profile_gemm_bias_add_reduce.cpp
#
profile_gemm_bias_add_reduce.cpp
profile_gemm_add_multiply.cpp
#
profile_gemm_add_multiply.cpp
profile_gemm_multiply_add.cpp
#
profile_gemm_multiply_add.cpp
profile_gemm_reduce.cpp
#
profile_gemm_reduce.cpp
profile_batched_gemm.cpp
#
profile_batched_gemm.cpp
profile_batched_gemm_reduce.cpp
#
profile_batched_gemm_reduce.cpp
profile_conv_fwd.cpp
#
profile_conv_fwd.cpp
profile_conv_fwd_bias_relu.cpp
#
profile_conv_fwd_bias_relu.cpp
profile_conv_fwd_bias_relu_add.cpp
#
profile_conv_fwd_bias_relu_add.cpp
profile_conv_bwd_data.cpp
#
profile_conv_bwd_data.cpp
profile_grouped_conv_fwd.cpp
#
profile_grouped_conv_fwd.cpp
profile_grouped_conv_bwd_weight.cpp
#
profile_grouped_conv_bwd_weight.cpp
profile_reduce.cpp
#
profile_reduce.cpp
profile_groupnorm_fwd.cpp
#
profile_groupnorm_fwd.cpp
profile_layernorm_fwd.cpp
#
profile_layernorm_fwd.cpp
profile_max_pool3d_fwd.cpp
#
profile_max_pool3d_fwd.cpp
profile_avg_pool3d_bwd.cpp
#
profile_avg_pool3d_bwd.cpp
profile_max_pool3d_bwd.cpp
#
profile_max_pool3d_bwd.cpp
profile_softmax.cpp
#
profile_softmax.cpp
profile_batchnorm_fwd.cpp
#
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
#
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
#
profile_batchnorm_infer.cpp
profile_grouped_conv_bwd_data.cpp
#
profile_grouped_conv_bwd_data.cpp
profile_conv_tensor_rearrange.cpp
#
profile_conv_tensor_rearrange.cpp
)
)
if
(
DL_KERNELS
)
#
if(DL_KERNELS)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp
)
#
list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
endif
()
#
endif()
#
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
#
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp
)
#
list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
list
(
APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp
)
#
list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
list
(
APPEND PROFILER_SOURCES profile_gemm_streamk.cpp
)
#
list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
list
(
APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp
)
#
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp
)
#
list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp
)
#
list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp
)
#
list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp
)
#
list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm.cpp
)
#
list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp
)
#
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
endif
()
#
endif()
#
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
#
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list
(
APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp
)
#
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
list
(
APPEND PROFILER_SOURCES profile_contraction_scale.cpp
)
#
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
endif
()
#
endif()
set
(
PROFILER_EXECUTABLE ckProfiler
)
set
(
PROFILER_EXECUTABLE ckProfiler
)
...
@@ -57,60 +57,60 @@ add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
...
@@ -57,60 +57,60 @@ add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
target_compile_options
(
${
PROFILER_EXECUTABLE
}
PRIVATE -Wno-global-constructors
)
target_compile_options
(
${
PROFILER_EXECUTABLE
}
PRIVATE -Wno-global-constructors
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE utility
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE utility
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_splitk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_splitk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_multiply_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_add_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_reduce_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bias_add_reduce_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_reduce_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv1d_fwd_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_fwd_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_fwd_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv1d_bwd_data_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_bwd_data_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv3d_bwd_data_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv1d_bwd_weight_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_weight_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_weight_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_bias_relu_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_bias_relu_add_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_normalization_fwd_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_softmax_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_reduce_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batchnorm_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_pool3d_fwd_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_avg_pool3d_bwd_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_max_pool_bwd_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_data_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_data_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_image_to_column_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_column_to_image_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
#
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
#
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_bilinear_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_scale_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
endif
()
#
endif()
#
#
#
if
(
DL_KERNELS
)
#
if(DL_KERNELS)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_multi_d_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
endif
()
#
endif()
#
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
#
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_fastgelu_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_relu_add_layernorm_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bilinear_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_add_fastgelu_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_streamk_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_fastgelu_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_gemm_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_add_relu_gemm_add_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_fastgelu_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
endif
()
#
endif()
rocm_install
(
TARGETS
${
PROFILER_EXECUTABLE
}
COMPONENT profiler
)
rocm_install
(
TARGETS
${
PROFILER_EXECUTABLE
}
COMPONENT profiler
)
profiler/src/profile_gemm_splitk.cpp
View file @
2e7c4ad7
...
@@ -122,94 +122,91 @@ int profile_gemm_splitk(int argc, char* argv[])
...
@@ -122,94 +122,91 @@ int profile_gemm_splitk(int argc, char* argv[])
return
pass
?
0
:
1
;
return
pass
?
0
:
1
;
};
};
#if 0
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
{
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}, F32{});
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}, F32{});
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
{
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}, F32{});
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}, F32{});
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
{
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}, F32{});
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}, F32{});
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
{
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}, F32{});
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}, F32{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
#endif
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
}
}
#if defined CK_ENABLE_FP8
#if defined CK_ENABLE_FP8
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
//
if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
//
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F16
{});
//
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
}
//
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
//
if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
//
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F16
{});
//
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
}
//
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
//
if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
//
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F16
{});
//
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
}
//
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
//
if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
//
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
//
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
}
//
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
//
if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_KN_MN)
{
//
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F8
{});
//
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F8{});
}
//
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
//
if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_NK_MN)
{
//
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F8
{});
//
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F8{});
}
//
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
//
if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_KN_MN)
{
//
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F8
{});
//
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F8{});
}
//
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
//
if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_NK_MN)
{
//
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F8
{});
//
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F8{});
}
//
}
#endif
#endif
else
return
0
;
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
return
1
;
}
}
}
REGISTER_PROFILER_OPERATION
(
OP_NAME
,
OP_DESC
,
profile_gemm_splitk
);
REGISTER_PROFILER_OPERATION
(
OP_NAME
,
OP_DESC
,
profile_gemm_splitk
);
script/cmake-ck-dev.sh
View file @
2e7c4ad7
...
@@ -8,10 +8,10 @@ MY_PROJECT_SOURCE=$1
...
@@ -8,10 +8,10 @@ MY_PROJECT_SOURCE=$1
cmake
\
cmake
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_CXX_FLAGS
=
"
--save-temps
-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
"gfx90
8;gfx90a;gfx940
"
\
-D
GPU_TARGETS
=
"gfx90
a
"
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
${
MY_PROJECT_SOURCE
}
${
MY_PROJECT_SOURCE
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment