Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e61489e2
Commit
e61489e2
authored
Dec 03, 2023
by
Jing Zhang
Browse files
fixed
parent
0fd1d636
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
353 additions
and
105 deletions
+353
-105
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
+6
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+6
-6
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+1
-5
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+262
-10
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+5
-3
library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
.../tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
+22
-21
profiler/include/profiler/profile_gemm_impl.hpp
profiler/include/profiler/profile_gemm_impl.hpp
+2
-2
profiler/include/profiler/profile_gemm_splitk_impl.hpp
profiler/include/profiler/profile_gemm_splitk_impl.hpp
+2
-2
profiler/src/profile_gemm_splitk.cpp
profiler/src/profile_gemm_splitk.cpp
+46
-51
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+1
-1
No files found.
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
View file @
e61489e2
...
@@ -30,20 +30,20 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -30,20 +30,20 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F
16
;
using
ADataType
=
F
8
;
using
BDataType
=
F8
;
using
BDataType
=
F8
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
// clang-format off
// clang-format off
...
@@ -51,7 +51,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
...
@@ -51,7 +51,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
16
,
256
,
4
,
8
,
16
,
16
,
1
,
4
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F16
,
ck
::
PipelineVersion
::
v1
,
ck
::
LoopScheduler
::
Interwave
>
;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 2, S<1, 16, 1, 8>, 8, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>;
// clang-format on
// clang-format on
#include "run_splitK_gemm_example.inc"
#include "run_splitK_gemm_example.inc"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
e61489e2
...
@@ -401,7 +401,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -401,7 +401,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
().
GetElementSpaceSize
();
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
().
GetElementSpaceSize
();
return
math
::
max
(
a_block_space_size
*
sizeof
(
FloatA
)
+
b_block_space_size
*
sizeof
(
FloatB
),
return
math
::
max
(
a_block_space_size
*
sizeof
(
FloatA
)
+
b_block_space_size
*
sizeof
(
FloatB
),
c_block_size
*
sizeof
(
Float
C
));
c_block_size
*
sizeof
(
Float
Acc
));
}
}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
...
@@ -834,8 +834,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -834,8 +834,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_k0_m_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_k0_m_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
auto
*
p_a_block
=
static_cast
<
FloatA
*>
(
p_shared_block
);
FloatA
*
p_a_block
=
static_cast
<
FloatA
*>
(
p_shared_block
);
auto
*
p_b_block
=
static_cast
<
FloatB
*>
(
p_shared_block
)
+
a_block_space_size
;
FloatB
*
p_b_block
=
static_cast
<
FloatB
*>
(
p_shared_block
)
+
a_block_space_size
;
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
...
@@ -892,7 +892,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -892,7 +892,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
Float
C
*>
(
p_shared_block
),
static_cast
<
Float
Acc
*>
(
p_shared_block
),
c_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
c_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
...
@@ -943,7 +943,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -943,7 +943,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// VGPR to LDS
// VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
Float
C
,
Float
Acc
,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -983,7 +983,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -983,7 +983,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
,
// BlockSliceLengths,
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
,
// BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Float
C
,
// typename SrcData,
Float
Acc
,
// typename SrcData,
FloatC
,
// typename DstData,
FloatC
,
// typename DstData,
decltype
(
c_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
e61489e2
...
@@ -1160,8 +1160,6 @@ struct ThreadwiseTensorSliceTransfer_v4
...
@@ -1160,8 +1160,6 @@ struct ThreadwiseTensorSliceTransfer_v4
src_tmp_vector
.
template
AsType
<
SrcData
>()(
i
)
=
src_buf
[
Number
<
src_offset
>
{}];
src_tmp_vector
.
template
AsType
<
SrcData
>()(
i
)
=
src_buf
[
Number
<
src_offset
>
{}];
});
});
}
}
#if 0
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
...
@@ -1171,15 +1169,13 @@ struct ThreadwiseTensorSliceTransfer_v4
...
@@ -1171,15 +1169,13 @@ struct ThreadwiseTensorSliceTransfer_v4
dst_tmp_vector
.
template
AsType
<
DstData
>()(
i
)
=
dst_tmp_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
src_tmp_vector
.
template
AsType
<
SrcData
>()[
i
]);
type_convert
<
DstData
>
(
src_tmp_vector
.
template
AsType
<
SrcData
>()[
i
]);
});
});
#endif
// copy data from dst_tmp_vector into dst_buf
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
type_convert
<
DstData
>
(
src_tmp_vector
.
template
AsType
<
SrcData
>()[
i
]);
});
});
});
});
}
}
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
e61489e2
...
@@ -278,7 +278,7 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
...
@@ -278,7 +278,7 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
index_t
soffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fadd.f32"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fadd.f32"
);
// buffer atomic-
add
fp
32
// buffer atomic-
max
fp
64
__device__
double
__device__
double
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
double
vdata
,
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
double
vdata
,
int32x4_t
rsrc
,
// dst_wave_buffer_resource
int32x4_t
rsrc
,
// dst_wave_buffer_resource
...
@@ -420,10 +420,124 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
...
@@ -420,10 +420,124 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
T
,
N
>::
type
;
if
constexpr
(
is_same
<
T
,
float
>::
value
)
// fp32
auto
raw_data
=
amd_buffer_load_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
{
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
);
if
constexpr
(
N
==
1
)
return
bit_cast
<
r_t
>
(
raw_data
);
{
return
llvm_amdgcn_raw_buffer_load_fp32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
;
tmp
.
AsType
<
float4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
float4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
float8_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
// fp16
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_fp16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
// use fp32 load to mimic fp16 load
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
half8_t
>
(
tmp
);
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
// bf16
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
bhalf8_t
>
(
tmp
);
}
}
else
// other datatype
{
using
r_t
=
typename
vector_type
<
T
,
N
>::
type
;
auto
raw_data
=
amd_buffer_load_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
);
return
bit_cast
<
r_t
>
(
raw_data
);
}
}
}
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
...
@@ -542,12 +656,150 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
...
@@ -542,12 +656,150 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
int8_t
,
sizeof
(
T
)
*
N
>::
type
;
if
constexpr
(
is_same
<
T
,
float
>::
value
)
// fp32
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp32
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
// fp16
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
#if 0
vector_type<half_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
// bf16
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_i16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
bhalf_t
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
bhalf_t
),
static_cast
<
index_t
>
(
coherence
));
}
}
else
{
using
r_t
=
typename
vector_type
<
int8_t
,
sizeof
(
T
)
*
N
>::
type
;
amd_buffer_store_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
bit_cast
<
r_t
>
(
src_thread_data
),
amd_buffer_store_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
bit_cast
<
r_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
);
dst_wave_addr_offset
);
}
}
}
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
...
...
include/ck/utility/type_convert.hpp
View file @
e61489e2
...
@@ -366,10 +366,12 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
...
@@ -366,10 +366,12 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
// use native conversion to float and convert to fp16
// return type_convert<half_t>(type_convert<float>(x));
// return type_convert<half_t>(type_convert<float>(x));
return
static_cast
<
half_t
>
(
x
);
// return static_cast<half_t>(x);
return
static_cast
<
half_t
>
(
bit_cast
<
int8_t
>
(
x
));
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
// constexpr bool negative_zero_nan = true;
return
utils
::
cast_from_f8
<
f8_t
,
half_t
,
negative_zero_nan
>
(
x
);
// return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
return
static_cast
<
half_t
>
(
bit_cast
<
int8_t
>
(
x
));
#endif
#endif
}
}
...
...
library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
View file @
e61489e2
set
(
GEMM_SPLITK_INSTANCES
)
set
(
GEMM_SPLITK_INSTANCES
)
list
(
APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
list
(
APPEND GEMM_SPLITK_INSTANCES
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
#device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
#device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
#device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp
#device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp
#device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp
#device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp
#device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
)
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
)
add_instance_library
(
device_gemm_splitk_instance
${
GEMM_SPLITK_INSTANCES
}
)
add_instance_library
(
device_gemm_splitk_instance
${
GEMM_SPLITK_INSTANCES
}
)
profiler/include/profiler/profile_gemm_impl.hpp
View file @
e61489e2
...
@@ -236,8 +236,8 @@ int profile_gemm_impl(int do_verification,
...
@@ -236,8 +236,8 @@ int profile_gemm_impl(int do_verification,
{
{
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
float
avg_time
=
StreamConfig
{
nullptr
,
time_kernel
,
0
,
50
,
200
});
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
...
...
profiler/include/profiler/profile_gemm_splitk_impl.hpp
View file @
e61489e2
...
@@ -200,8 +200,8 @@ bool profile_gemm_splitk_impl(int do_verification,
...
@@ -200,8 +200,8 @@ bool profile_gemm_splitk_impl(int do_verification,
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
float
ave_time
=
StreamConfig
{
nullptr
,
time_kernel
,
0
,
50
,
200
});
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
...
...
profiler/src/profile_gemm_splitk.cpp
View file @
e61489e2
...
@@ -122,96 +122,91 @@ int profile_gemm_splitk(int argc, char* argv[])
...
@@ -122,96 +122,91 @@ int profile_gemm_splitk(int argc, char* argv[])
return
pass
?
0
:
1
;
return
pass
?
0
:
1
;
};
};
#if 0
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
{
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}, F32{});
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}, F32{});
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
{
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}, F32{});
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}, F32{});
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
{
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}, F32{});
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}, F32{});
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
{
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}, F32{});
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}, F32{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
#endif
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
}
}
#if defined CK_ENABLE_FP8
#if defined CK_ENABLE_FP8
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
//
if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
//
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F16
{});
//
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
}
//
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
//
if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
//
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F16
{});
//
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
}
//
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
//
if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
//
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F16
{});
//
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
}
//
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
//
if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
//
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
//
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
}
//
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F16
{});
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
}
}
#if 0
// if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_KN_MN)
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_KN_MN)
//{
{
// return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F8{});
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F8{});
//}
}
// if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_NK_MN)
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_NK_MN)
//{
{
// return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F8{});
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F8{});
//}
}
// if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_KN_MN)
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_KN_MN)
//{
{
// return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F8{});
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F8{});
//}
}
// if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_NK_MN)
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_NK_MN)
//{
{
// return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F8{});
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F8{});
//}
}
#endif
#endif
#endif
else
return
0
;
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
return
1
;
}
}
}
REGISTER_PROFILER_OPERATION
(
OP_NAME
,
OP_DESC
,
profile_gemm_splitk
);
REGISTER_PROFILER_OPERATION
(
OP_NAME
,
OP_DESC
,
profile_gemm_splitk
);
script/cmake-ck-dev.sh
View file @
e61489e2
...
@@ -11,7 +11,7 @@ cmake
...
@@ -11,7 +11,7 @@ cmake
-D
CMAKE_CXX_FLAGS
=
"--save-temps -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_CXX_FLAGS
=
"--save-temps -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
"gfx9
42
"
\
-D
GPU_TARGETS
=
"gfx9
0a
"
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
${
MY_PROJECT_SOURCE
}
${
MY_PROJECT_SOURCE
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment