gaoqiong / composable_kernel_ROCM · Commits

Commit 0f3b88bf, authored Oct 11, 2024 by Jing Zhang

    add a prototype of int4

parent cfac9497

Showing 11 changed files with 216 additions and 42 deletions (+216 -42)
CMakeLists.txt                                                                     +8  -8
example/01_gemm/CMakeLists.txt                                                     +1  -0
example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp                                           +21 -7
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp (new)                                   +93 -0
example/01_gemm/gemm_xdl_fp16_v3.cpp                                               +10 -10
example/01_gemm/run_gemm_example_v2.inc                                            +3  -1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp           +13 -0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp        +36 -4
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp   +29 -11
include/ck/utility/amd_buffer_addressing.hpp                                       +1  -1
include/ck/utility/data_type.hpp                                                   +1  -0
CMakeLists.txt

@@ -543,7 +543,7 @@ ENDIF()
 ENDFOREACH()

 add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
-add_subdirectory(library)
+# add_subdirectory(library)

 if(NOT GPU_ARCHS)
     rocm_package_setup_component(tests
@@ -556,20 +556,20 @@ if(NOT GPU_ARCHS)
     PACKAGE_NAME examples
 )
 add_subdirectory(example)
-if(BUILD_TESTING)
-    add_subdirectory(test)
-endif()
+# if(BUILD_TESTING)
+# add_subdirectory(test)
+# endif()
 endif()

 rocm_package_setup_component(profiler
     LIBRARY_NAME composablekernel
     PACKAGE_NAME ckprofiler
 )
-add_subdirectory(profiler)
+# add_subdirectory(profiler)

-if(CK_USE_CODEGEN AND (GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
-    add_subdirectory(codegen)
-endif()
+# if(CK_USE_CODEGEN AND (GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
+# add_subdirectory(codegen)
+# endif()

 #Create an interface target for the include only files and call it "composablekernels"
 include(CMakePackageConfigHelpers)
example/01_gemm/CMakeLists.txt

@@ -29,6 +29,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
 add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)

 add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
+add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)

 add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp

@@ -5,8 +5,8 @@
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"

-using ADataType        = ck::f8_t;
-using BDataType        = ck::half_t;
+using ADataType        = ck::half_t;
+using BDataType        = ck::f8_t;
 using AccDataType      = float;
 using CShuffleDataType = ck::half_t;
 using CDataType        = ck::half_t;
@@ -27,17 +27,31 @@ using DeviceGemmV2Instance =
         ALayout, BLayout, CLayout,
         ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
         AElementOp, BElementOp, CElementOp, GemmDefault,
+#if 0
         64,
         16, 16,
-        64, 16, 8,
+        256, 8, 16,
         16, 16,
         1, 1,
-        S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
-        2, 16, 16, 0,
-        S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
+        S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
+        2, 8, 8, 0,
+        S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>,
         2, 16, 16, 0,
         1, 1, S<1, 16, 1, 4>, 4,
-        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
+        ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
+#else
+        128,
+        16, 32,
+        128, 8, 16,
+        16, 16,
+        1, 1,
+        S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
+        2, 8, 8, 0,
+        S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
+        2, 16, 16, 0,
+        1, 1, S<1, 16, 1, 8>, 4,
+        ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
+#endif
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
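As a reading aid, here is a rough map of the leading tuning arguments in the active (#else) configuration above, based on the DeviceGemm_Xdl_CShuffleV3 parameter order in CK around this revision; treat the labels as assumptions, not the authoritative signature:

    // Assumed parameter meaning (reading aid only):
    //   128,          // BlockSize
    //   16, 32,       // MPerBlock, NPerBlock
    //   128, 8, 16,   // KPerBlock, AK1, BK1
    //   16, 16,       // MPerXDL, NPerXDL
    //   1, 1,         // MXdlPerWave, NXdlPerWave
    //   S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, ...
    //                 // A/B block-transfer thread cluster lengths and access
    //                 // orders follow, then the CShuffle and pipeline
    //                 // (scheduler/version) parameters.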
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"

using ADataType        = ck::half_t;
using BDataType        = ck::pk_i4_t;
using AccDataType      = float;
using CShuffleDataType = ck::half_t;
using CDataType        = ck::half_t;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

inline __host__ __device__ ck::half2_t type_convert_packed_i4_to_half2(ck::pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
    uint8_t x_l  = (x_u8 & 0x0f);
    uint8_t x_h  = (x_u8 & 0xf0) >> 4;

    auto l_f16 = ck::type_convert<ck::half_t>(x_l);
    auto h_f16 = ck::type_convert<ck::half_t>(x_h);

    return {l_f16, h_f16};
}

struct ElementwisePackedI4ToHalf2
{
    __host__ __device__ void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
    {
        y = type_convert_packed_i4_to_half2(x);
    }
    constexpr const static bool is_pack2_invocable = true;
};

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// clang-format off
using DeviceGemmV2Instance =
    ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
        ALayout, BLayout, CLayout,
        ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
        AElementOp, BElementOp, CElementOp, GemmDefault,
#if 0
        64,
        16, 16,
        256, 8, 32,
        16, 16,
        1, 1,
        S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
        2, 8, 8, 0,
        S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
        2, 32, 32, 0,
        1, 1, S<1, 16, 1, 4>, 4,
        ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>;
#else
        128,
        16, 32,
        128, 8, 32,
        16, 16,
        1, 1,
        S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
        2, 8, 8, 0,
        S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
        2, 32, 32, 0,
        1, 1, S<1, 16, 1, 8>, 4,
        ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
#endif
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                        BDataType,
                                                                        CDataType,
                                                                        AccDataType,
                                                                        PassThrough,
                                                                        PassThrough,
                                                                        PassThrough>;

#include "run_gemm_example_v2.inc"

int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
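The packed-int4 decode in type_convert_packed_i4_to_half2 is plain nibble masking and shifting. Below is a minimal host-only sketch of the same arithmetic, with float standing in for ck::half_t so it compiles without ROCm headers:

#include <cstdint>
#include <cstdio>

int main()
{
    uint8_t packed = 0x2B;                 // one pk_i4_t byte: high nibble 0x2, low nibble 0xB
    uint8_t lo     = packed & 0x0f;        // low nibble  -> 11
    uint8_t hi     = (packed & 0xf0) >> 4; // high nibble -> 2

    // Like the prototype above, the nibbles are treated as unsigned (0..15);
    // no sign extension to [-8, 7] is performed at this stage.
    float lo_f = static_cast<float>(lo);
    float hi_f = static_cast<float>(hi);

    std::printf("low = %g, high = %g\n", lo_f, hi_f); // low = 11, high = 2
    return 0;
}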
example/01_gemm/gemm_xdl_fp16_v3.cpp

@@ -12,7 +12,7 @@ using CShuffleDataType = ck::half_t;
 using CDataType        = ck::half_t;

 using ALayout = Row;
-using BLayout = Row;
+using BLayout = Col;
 using CLayout = Row;

 using AElementOp = PassThrough;
@@ -27,17 +27,17 @@ using DeviceGemmV2Instance =
         ALayout, BLayout, CLayout,
         ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
         PassThrough, PassThrough, PassThrough, GemmDefault,
-        256,
-        224, 256,
-        64, 8, 2,
+        64,
+        16, 16,
+        256, 8, 8,
         16, 16,
-        7, 8,
-        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
+        1, 1,
+        S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
         2, 8, 8, 0,
-        S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
-        1, 8, 2, 0,
-        1, 2, S<1, 32, 1, 8>, 8,
-        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
+        S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
+        2, 8, 8, 0,
+        1, 1, S<1, 16, 1, 4>, 4,
+        ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>;
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::
example/01_gemm/run_gemm_example_v2.inc

@@ -228,6 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     }

     bool pass = true;
+#if 0
     if(config.do_verification)
     {
         auto ref_gemm = ReferenceGemmInstance{};
@@ -257,11 +258,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
                               get_atol<CDataType>());
 #endif
     }
+#endif

     if(config.time_kernel)
     {
-        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 5, 10, true, 4});
+        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});

         std::size_t flop = 2_uz * M * N * K;
         std::size_t num_btype =
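The StreamConfig change only affects benchmarking: the iteration counts go up. A hypothetical mirror of the aggregate is sketched below; the field names are assumptions for illustration and are not taken from the library's declaration:

// Field names below are assumptions for illustration only.
struct StreamConfigSketch
{
    void* stream_id      = nullptr; // hipStream_t in CK
    bool  time_kernel    = false;
    int   log_level      = 0;
    int   cold_niters    = 5;  // warm-up launches before timing
    int   nrepeat        = 10; // timed launches averaged into ave_time
    bool  flush_cache    = false;
    int   rotating_count = 1;
};
// Under this reading, the commit raises warm-up launches from 5 to 20, timed
// repetitions from 10 to 50, and the trailing count from 4 to 50.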
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -22,6 +22,19 @@ struct PassThroughPack2
         auto t = type_convert<float2_t>(x);
         y      = type_convert<half2_t>(t);
     }
+
+    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
+    {
+        uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
+        uint8_t x_l  = (x_u8 & 0x0f) >> 0;
+        uint8_t x_h  = (x_u8 & 0xf0) >> 4;
+
+        auto l_f16 = ck::type_convert<ck::half_t>(x_l);
+        auto h_f16 = ck::type_convert<ck::half_t>(x_h);
+
+        y = {l_f16, h_f16};
+    }
     constexpr const static bool is_pack2_invocable = true;
 };
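A host-runnable sketch of what the new overload does, with std::pair<float, float> standing in for half2_t and uint8_t for pk_i4_t (names are illustrative); the is_pack2_invocable flag is what lets callers dispatch one call per packed byte instead of one per element:

#include <cstdint>
#include <cstdio>
#include <utility>

struct PassThroughPack2Sketch
{
    void operator()(std::pair<float, float>& y, const uint8_t& x) const
    {
        y = {static_cast<float>(x & 0x0f),         // low nibble
             static_cast<float>((x & 0xf0) >> 4)}; // high nibble
    }
    static constexpr bool is_pack2_invocable = true; // advertise pair-wise invocation
};

int main()
{
    std::pair<float, float> y;
    PassThroughPack2Sketch{}(y, uint8_t{0x2B});
    std::printf("%g %g\n", y.first, y.second); // prints: 11 2
    return 0;
}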
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1007,6 +1007,13 @@ struct ThreadwiseTensorSliceTransfer_v4
     using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));

+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
     __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
         : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
     {
@@ -1015,6 +1022,8 @@ struct ThreadwiseTensorSliceTransfer_v4
         static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                       "wrong! Not divisible");
+        static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector == 1)),
+                      "pk data N cannot be 1");
     }

     template <typename SrcRefToOriginDisplacement,
@@ -1109,7 +1118,7 @@ struct ThreadwiseTensorSliceTransfer_v4
         move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);

-        vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
+        vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;

         using src_vector_t = typename decltype(src_tmp_vector)::type;
@@ -1120,7 +1129,7 @@ struct ThreadwiseTensorSliceTransfer_v4
         if constexpr(SrcBuffer::IsDynamicBuffer())
         {
             src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
+                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize, is_src_valid);
         }
         else if constexpr(SrcBuffer::IsStaticBuffer())
         {
@@ -1129,11 +1138,34 @@ struct ThreadwiseTensorSliceTransfer_v4
                     src_ref_to_origin_disp_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

-                src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
+                src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset / PackedSize>{}];
             });
         }

-        if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
+        if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
                      is_same<remove_cvref_t<DstData>, half_t>::value)
+        {
+            // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
+            // DstData)
+            vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
+
+            using dst_v_t = typename vector_type_maker_t<DstData, PackedSize>::type;
+            using src_v_t = typename vector_type_maker_t<SrcData, 1>::type;
+
+            static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
+                ck::tensor_operation::element_wise::PassThroughPack2{}(
+                    dst_tmp_vector.template AsType<dst_v_t>()(i),
+                    src_tmp_vector.template AsType<src_v_t>()[i]);
+            });
+
+            // copy data from dst_tmp_vector into dst_buf
+            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                    dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
+
+                dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
+            });
+        }
+        else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
                           is_same<remove_cvref_t<DstData>, half_t>::value && SrcScalarPerVector % 2 == 0)
         {
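The recurring GetOffset() / PackedSize and src_offset / PackedSize divisions translate logical element offsets into storage offsets: with pk_i4_t, two logical int4 elements share one byte. An illustrative sketch of that mapping (not CK code):

#include <cstddef>
#include <cstdio>

// Logical element offset -> storage (byte) offset for a packed carrier type.
constexpr std::size_t storage_offset(std::size_t element_offset, std::size_t packed_size)
{
    return element_offset / packed_size;
}

int main()
{
    constexpr std::size_t PackedSize = 2; // pk_i4_t packs two int4 values per byte
    static_assert(storage_offset(10, PackedSize) == 5, "logical element 10 lives in byte 5");
    std::printf("element 10 -> byte %zu\n", storage_offset(10, PackedSize));
    return 0;
}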
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp

@@ -31,8 +31,8 @@ template <typename SliceLengths,
           typename DstDimAccessOrder,
           index_t SrcVectorDim,
           index_t DstVectorDim,
-          index_t SrcScalarPerVector,
-          index_t DstScalarPerVector,
+          index_t SrcScalarPerVector_,
+          index_t DstScalarPerVector_,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
@@ -55,6 +55,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     static constexpr auto I0 = Number<0>{};

+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
+    static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
+    static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
+
     __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
                                                             const Index& src_slice_origin,
@@ -67,6 +78,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
           src_element_op_(src_element_op),
           dst_element_op_(dst_element_op)
     {
         static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>, "SrcData != DstData");
+        static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> &&
+                        (SrcScalarPerVector == 1 || DstScalarPerVector == 1)), "pk data N cannot be 1");
     }

     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -95,11 +108,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector * PackedSize>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

-        static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
+        static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector * PackedSize) == 0,
                       "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");

         constexpr auto src_dim_access_order = SrcDimAccessOrder{};
@@ -181,7 +194,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 using src_vector_t = typename src_vector_type::type;

                 auto src_vector_container =
-                    src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), true)};
+                    src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};

                 using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
                 using dst_vector_t   = typename dst_vector_type::type;
@@ -279,7 +292,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // OOB Check
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector * PackedSize>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -368,9 +381,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         constexpr auto scalar_per_access = generate_sequence(
             detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
-                                                             SrcScalarPerVector,
+                                                             SrcScalarPerVector * PackedSize,
                                                              DstVectorDim,
-                                                             DstScalarPerVector>{},
+                                                             DstScalarPerVector * PackedSize>{},
             Number<nDim>{});

         constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
@@ -410,7 +423,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         }
         else
         {
-            static_ford<SliceLengths>{}([&](auto idx) {
+            constexpr auto packed_per_access = generate_sequence(
+                detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
+
+            constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
+
+            static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
                 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
             });
         }
@@ -438,7 +456,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // src scalar per access on each dim
         // TODO: don't use this
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector * PackedSize>{}, Number<nDim>{});

         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -532,7 +550,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // copy data from dst_vector_container to dst_buf
-        dst_buf.template Set<dst_vector_t>(dst_coord_.GetOffset(),
+        dst_buf.template Set<dst_vector_t>(dst_coord_.GetOffset() / PackedSize,
                                            is_dst_valid,
                                            dst_vector_container.template AsType<dst_vector_t>()[I0]);
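The renamed SrcScalarPerVector_/DstScalarPerVector_ parameters are rescaled once by an immediately-invoked constexpr lambda. A standalone sketch of that selection, with pk_i4_t modeled as unsigned char to match this commit's data_type.hpp:

#include <type_traits>

using pk_i4_t = unsigned char; // stand-in matching this commit's data_type.hpp

template <typename SrcData, int SrcScalarPerVector_>
struct TransferTraitsSketch
{
    static constexpr int PackedSize = []() {
        if constexpr(std::is_same_v<std::remove_cv_t<SrcData>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    // The user-facing vector width is divided by PackedSize once, up front,
    // so the rest of the class can work in storage units.
    static constexpr int SrcScalarPerVector = SrcScalarPerVector_ / PackedSize;
};

static_assert(TransferTraitsSketch<pk_i4_t, 8>::SrcScalarPerVector == 4);
static_assert(TransferTraitsSketch<float, 8>::SrcScalarPerVector == 8);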
include/ck/utility/amd_buffer_addressing.hpp

@@ -429,7 +429,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-        (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
+        (is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");

     using r_t = typename vector_type<T, N>::type;
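The static_assert admits the same power-of-two vector widths for pk_i4_t as for the other byte-sized types; a constexpr restatement of that guard (illustrative, not CK's implementation):

constexpr bool valid_buffer_vector_size(int n)
{
    return n == 1 || n == 2 || n == 4 || n == 8 || n == 16;
}

static_assert(valid_buffer_vector_size(8), "8-wide loads are allowed");
static_assert(!valid_buffer_vector_size(3), "3-wide loads are rejected");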
include/ck/utility/data_type.hpp

@@ -12,6 +12,7 @@ using half_t = _Float16;
 using int4_t = _BitInt(4);
 using f8_t   = _BitInt(8);
 using bf8_t  = unsigned _BitInt(8);
+using pk_i4_t = unsigned char;

 // vector_type
 template <typename T, index_t N>
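With pk_i4_t declared as a plain byte, packing is ordinary bit arithmetic. A sketch of producing one carrier byte from two 4-bit values, assuming (as the unpack code in this commit does) that the low nibble holds the first element:

#include <cstdint>

using pk_i4_t = unsigned char;

constexpr pk_i4_t pack_int4x2(int lo, int hi)
{
    // keep 4 bits of each value; hi occupies the upper nibble
    return static_cast<pk_i4_t>((lo & 0x0f) | ((hi & 0x0f) << 4));
}

static_assert(pack_int4x2(0xB, 0x2) == 0x2B, "low nibble 0xB, high nibble 0x2");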