gaoqiong / composable_kernel_ROCM · Commits · 1d82d465

Commit 1d82d465, authored Oct 27, 2024 by Jing Zhang (parent f03dda48)

    add bfp16 support
Showing 3 changed files with 74 additions and 50 deletions:

  example/01_gemm/CMakeLists.txt                                                +1   -0
  include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp      +72  -20
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp   +1   -30
example/01_gemm/CMakeLists.txt

@@ -30,6 +30,7 @@ add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
 add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
 add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp)
+add_example_executable(example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
 add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -11,18 +11,16 @@
 namespace ck {

-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
 __host__ __device__ inline half4_t pki4_to_half4(int q)
 {
     const int LO = 0x000f000f;
     const int HI = 0x00f000f0;
     const int EX = 0x64006400;

-    // Guarantee that the `(a & b) | c` operations are LOP3s.
-    // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
-    // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
     int lo = amd_assembly_and_or_b32(q, LO, EX);
     int hi = amd_assembly_and_or_b32(q, HI, EX);

+    // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
+    // directly into `SUB` and `ADD`.
     const int SUB = 0xE408E408; // -8
     const int MUL = 0x2c002c00; // 1/16
     const int ADD = 0xd480d480; // -72
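Editorial note, not part of the commit: pki4_to_half4 uses the FasterTransformer trick of OR-ing each 4-bit field into the mantissa of the fp16 value 1024.0 (0x6400), so the low nibble becomes 1024 + v and the high nibble 1024 + 16*v; the fused constants then recover the signed value v - 8. A minimal host-only C++ sketch that checks those constants (the fp16 decoder below is hand-rolled and only handles normal values, which is enough here):

// Host-only check of the fused fp16 constants (editorial sketch, not CK code).
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Exact value of a *normal* IEEE fp16 bit pattern; sufficient for this test.
static double fp16_bits(uint16_t b)
{
    int    sign = (b >> 15) ? -1 : 1;
    int    exp  = ((b >> 10) & 0x1f) - 15;    // remove exponent bias
    double mant = 1.0 + (b & 0x3ff) / 1024.0; // implicit leading one
    return sign * mant * std::ldexp(1.0, exp);
}

int main()
{
    assert(fp16_bits(0xE408) == -1032.0); // SUB = -(1024 + 8)
    assert(fp16_bits(0x2c00) == 0.0625);  // MUL = 1/16
    assert(fp16_bits(0xd480) == -72.0);   // ADD

    for(int v = 0; v < 16; ++v)
    {
        // low nibble: OR into the mantissa of 0x6400 (= 1024.0) -> 1024 + v
        double lo = fp16_bits(uint16_t(0x6400 | v)) + fp16_bits(0xE408);
        // high nibble sits 4 bits up in the mantissa -> 1024 + 16*v
        double hi = fp16_bits(uint16_t(0x6400 | (v << 4))) * fp16_bits(0x2c00) +
                    fp16_bits(0xd480);
        assert(lo == v - 8.0 && hi == v - 8.0); // both recover the signed int4
    }
    std::puts("pki4_to_half4 constants verified");
}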
@@ -40,17 +38,6 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
 __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
 {
-#if 0
-    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
-    uint8_t x_l  = (x_u8 & 0x0f) >> 0;
-    uint8_t x_h  = (x_u8 & 0xf0) >> 4;
-
-    auto l_f16 = ck::type_convert<ck::half_t>(x_l - 8);
-    auto h_f16 = ck::type_convert<ck::half_t>(x_h - 8);
-
-    return {h_f16, l_f16};
-
-#elif 1
     uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
     int x_l      = (x_u8 & 0x0f);
@@ -62,12 +49,51 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
     int lo = (x_l | x_h) | EX;

     return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
-#else
-    int32_t res = bit_cast<int8_t>(q);
-    return bit_cast<half2_t>(res);
-#endif
 }

+__host__ __device__ inline bhalf4_t pki4_to_bhalf4(pk_i4x2_t i4s)
+{
+    uint32_t q = bit_cast<uint16_t>(i4s);
+
+    uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
+
+    static constexpr uint32_t fp32_base = 0x4B000000;
+
+    float fp32_intermediates[4];
+
+    uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
+
+    fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
+    fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
+    fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
+    fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
+
+    fp32_intermediates[0] -= 8388616.f;
+    fp32_intermediates[1] -= 8388616.f;
+    fp32_intermediates[2] -= 8388616.f;
+    fp32_intermediates[3] -= 8388616.f;
+
+    vector_type<bhalf_t, 4> res;
+    res.template AsType<bhalf2_t>()(Number<0>{}) =
+        __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
+    res.template AsType<bhalf2_t>()(Number<1>{}) =
+        __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
+
+    return res.template AsType<bhalf4_t>()[Number<0>{}];
+}
+
+__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
+{
+    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
+    float x_h    = ((x_u8 & 0x0f) >> 0) - 8;
+    float x_l    = ((x_u8 & 0xf0) >> 4) - 8;
+
+    vector_type<bhalf_t, 2> res;
+    res.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(x_l);
+    res.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(x_h);
+
+    return res.template AsType<bhalf2_t>()[Number<0>{}];
+}
+
 namespace tensor_operation {
 namespace element_wise {
@@ -102,6 +128,32 @@ struct PassThroughPack8
 #endif
     }

+    __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
+    {
+#if 1
+        vector_type<bhalf_t, 8> result;
+
+        result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
+        result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
+
+        y = result.template AsType<bhalf8_t>()[Number<0>{}];
+#else
+        vector_type<bhalf_t, 8> dst;
+        vector_type<pk_i4_t, 4> src{x};
+
+        dst.template AsType<bhalf2_t>()(Number<0>{}) =
+            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+        dst.template AsType<bhalf2_t>()(Number<1>{}) =
+            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<1>{}]);
+        dst.template AsType<bhalf2_t>()(Number<2>{}) =
+            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<2>{}]);
+        dst.template AsType<bhalf2_t>()(Number<3>{}) =
+            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
+
+        y = dst.template AsType<bhalf8_t>()[Number<0>{}];
+#endif
+    }
+
     constexpr const static bool is_pack8_invocable = true;
 };
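Editorial note: a pk_i4x4_t spans 32 bits, i.e. eight 4-bit fields, while pki4_to_bhalf4 consumes four of them at a time, which is why the fast path above calls it once on the word and once on the word shifted right by 16. A trivial sketch of that split:

// Editorial sketch: how the two pki4_to_bhalf4 calls partition a pk_i4x4_t.
#include <cassert>
#include <cstdint>

int main()
{
    uint32_t x    = 0x89ABCDEFu;         // 32 bits = eight packed 4-bit fields
    uint32_t low  = x & 0xFFFFu;         // seen by pki4_to_bhalf4(bit_cast<int>(x))
    uint32_t high = (x >> 16) & 0xFFFFu; // seen by the second call (x >> 16)
    assert(low == 0xCDEFu && high == 0x89ABu); // all eight nibbles, exactly once
}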
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1147,36 +1147,7 @@ struct ThreadwiseTensorSliceTransfer_v4
                 });
             }

-            if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
-                         is_same<remove_cvref_t<DstData>, half_t>::value)
-            {
-                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
-                // DstData)
-                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
-
-                constexpr index_t pack_size = 8;
-
-                static_assert(SrcScalarPerVector % pack_size == 0, "");
-
-                using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
-                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
-
-                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
-                    ck::tensor_operation::element_wise::PassThroughPack8{}(
-                        dst_tmp_vector.template AsType<dst_v_t>()(i),
-                        src_tmp_vector.template AsType<src_v_t>()[i]);
-                });
-
-                // copy data from dst_tmp_vector into dst_buf
-                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
-                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
-                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
-
-                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
-                });
-            }
-            else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
-                              is_same<remove_cvref_t<DstData>, f8_t>::value)
+            if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
             {
                 // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
                 // DstData)
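Editorial sketch of the dispatch pattern this hunk leaves behind (the names below are illustrative stand-ins, not CK's types): the per-destination-type branches collapse into a single `if constexpr` on the packed-int4 source type, and overload resolution on PassThroughPack8::operator() picks the fp16 or bf16 unpack at compile time.

// Editorial sketch, not CK code: compile-time dispatch by overload set.
#include <type_traits>

struct pk_i4_t  { unsigned char data; };  // two packed signed int4 (stand-in)
struct half8_t  { unsigned short v[8]; }; // stand-ins for CK vector types
struct bhalf8_t { unsigned short v[8]; };

struct PassThroughPack8
{
    // one overload per destination type; the caller never branches on DstData
    void operator()(half8_t&, const pk_i4_t*) const  { /* fp16 unpack */ }
    void operator()(bhalf8_t&, const pk_i4_t*) const { /* bf16 unpack */ }
};

template <typename SrcData, typename DstVector>
void transfer(const SrcData* src, DstVector& dst)
{
    if constexpr(std::is_same_v<SrcData, pk_i4_t>)
    {
        PassThroughPack8{}(dst, src); // resolved at compile time per DstVector
    }
}

int main()
{
    pk_i4_t  src[4]{};
    half8_t  h{};
    bhalf8_t b{};
    transfer(src, h); // fp16 path
    transfer(src, b); // bf16 path
}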