gaoqiong / composable_kernel_ROCM · Commits

Commit c13366af
Authored Oct 18, 2024 by Jing Zhang
Parent: 24e18ae8

    add fast pki4 to half conversion

Showing 7 changed files with 91 additions and 12 deletions.
Changed files:

  example/01_gemm/run_gemm_example_v2.inc                                       +1   -1
  include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp      +46  -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp        +14  -0
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp   +5   -3
  include/ck/utility/amd_inline_asm.hpp                                         +15  -0
  include/ck/utility/amd_xdlops.hpp                                             +2   -2
  include/ck/utility/data_type.hpp                                              +8   -6
example/01_gemm/run_gemm_example_v2.inc

@@ -139,7 +139,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     {
     case 0:
         a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{0x11});
+        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{0x99});
         break;
     case 1:
         a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
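For reference (a sketch, not part of the commit): under the offset-binary int4 encoding that pki4_to_half4 below implements (stored nibble = value + 8), the old fill byte 0x11 decodes to the pair (-7, -7) while the new 0x99 decodes to (1, 1). A minimal host-side check, with decode_pk_i4 a hypothetical helper:

#include <cstdio>

// Hypothetical helper: decode one byte holding two 4-bit values with the
// -8 zero point that pki4_to_half4 (below) fuses into its constants.
void decode_pk_i4(unsigned char byte, int out[2])
{
    out[0] = (byte & 0x0f) - 8;        // low nibble
    out[1] = ((byte >> 4) & 0x0f) - 8; // high nibble
}

int main()
{
    int a[2], b[2];
    decode_pk_i4(0x11, a); // old fill value -> (-7, -7)
    decode_pk_i4(0x99, b); // new fill value -> ( 1,  1)
    printf("0x11 -> (%d, %d), 0x99 -> (%d, %d)\n", a[0], a[1], b[0], b[1]);
    return 0;
}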
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -7,11 +7,57 @@
 #include "ck/utility/math.hpp"
 #include "ck/utility/math_v2.hpp"
 #include "ck/utility/type_convert.hpp"
+#include "ck/utility/amd_inline_asm.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace element_wise {

+__device__ inline half4_t pki4_to_half4(int q)
+{
+    const int LO = 0x000f000f;
+    const int HI = 0x00f000f0;
+    const int EX = 0x64006400;
+
+    // Guarantee that the `(a & b) | c` operations are LOP3s.
+    // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+    // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+
+    int lo = (q & LO) | EX;
+    int hi = (q & HI) | EX;
+
+    // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
+    // directly into `SUB` and `ADD`.
+    const int SUB = 0xE408E408; // -1032 (fuses the -8 zero point)
+    const int MUL = 0x2c002c00; // 1/16
+    const int ADD = 0xd480d480; // -72
+
+    vector_type<half_t, 4> res;
+
+    res.template AsType<half2_t>()(Number<0>{}) =
+        amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
+    res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
+        bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
+
+    return res.template AsType<half4_t>()[Number<0>{}];
+}
+
+struct PassThroughPack8
+{
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const;
+
+    __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
+    {
+        vector_type<half_t, 8> result;
+
+        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
+        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
+
+        y = result.template AsType<half8_t>()[Number<0>{}];
+    }
+
+    constexpr const static bool is_pack8_invocable = true;
+};
+
 struct PassThroughPack2
 {
     template <typename Y, typename X>
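Background on the magic constants (a sketch, not part of the commit): EX = 0x6400 is fp16 1024.0, and OR-ing a nibble n into its low mantissa bits yields exactly 1024 + n. The LO mask extracts nibbles 0 and 4 of q, so one packed add with 0xE408 (-1032.0) returns n - 8. The HI mask extracts nibbles 1 and 5, which sit four bits higher in the mantissa, so the lane reads 1024 + 16n; one packed FMA with 0x2c00 (1/16) and 0xd480 (-72.0) rescales it to n - 8 in a single instruction. PassThroughPack8's second call on q >> 8 covers nibbles 2, 3, 6, 7 the same way. A host-side check using float stand-ins for fp16 (exact for these small values):

#include <cassert>
#include <cstdio>

int main()
{
    // fp16 payloads of the constants:
    // 0x6400 = 1024.0, 0xE408 = -1032.0, 0x2c00 = 0.0625, 0xd480 = -72.0
    for (int n = 0; n < 16; ++n) // every possible nibble
    {
        // LO path: lane reads 1024 + n; one v_pk_add_f16 with -1032.
        float lo = (1024.0f + n) + (-1032.0f);
        // HI path: nibble sits 4 bits higher, lane reads 1024 + 16*n;
        // one v_pk_fma_f16 rescales and shifts at once.
        float hi = (1024.0f + 16.0f * n) * 0.0625f + (-72.0f);
        assert(lo == float(n - 8) && hi == float(n - 8));
    }
    printf("both paths decode every nibble to n - 8\n");
    return 0;
}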
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

@@ -387,6 +387,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
         else
         {
+#if 1
             // not pad N or K
             const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                 b_grid_desc_nraw_kraw,
@@ -394,6 +395,19 @@ struct GridwiseGemm_xdl_cshuffle_v3
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+#else
+            const index_t N0 = N / NPerBlock;
+            const index_t N1 = NPerBlock;
+
+            const auto b_grid_desc_n0_bk0_n1_bk1 =
+                make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value));
+
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_n0_bk0_n1_bk1,
+                make_tuple(make_pass_through_transform(BK0),
+                           make_merge_transform(make_tuple(N0, N1)),
+                           make_pass_through_transform(BK1Value)),
+                make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+#endif
             return b_grid_desc_bk0_n_bk1;
         }
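As I read it, the retained #if 1 branch builds B's [BK0, N, BK1] view directly from the raw (n, k) descriptor, while the disabled #else branch assumes B was pre-shuffled into a packed [N0, BK0, N1, BK1] buffer (N split into blocks of NPerBlock) and merges (N0, N1) back into N. An illustrative host-side sketch of that coordinate map, with all tile sizes assumed:

#include <cstdio>

int main()
{
    const int NPerBlock = 128, BK0 = 32, BK1 = 8; // assumed tile sizes
    const int N  = 512;
    const int N0 = N / NPerBlock, N1 = NPerBlock;

    int bk0 = 3, n = 300, bk1 = 5;  // an arbitrary (bk0, n, bk1) coordinate
    int n0 = n / N1, n1 = n % N1;   // inverse of make_merge_transform(N0, N1)
    // Packed [N0, BK0, N1, BK1] layout: row-major strides.
    long offset = ((long(n0) * BK0 + bk0) * N1 + n1) * BK1 + bk1;
    printf("(bk0=%d, n=%d, bk1=%d) -> offset %ld of %ld\n",
           bk0, n, bk1, offset, long(N0) * BK0 * N1 * BK1);
    return 0;
}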
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1150,12 +1150,14 @@ struct ThreadwiseTensorSliceTransfer_v4
             // DstData)
             vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

-            constexpr index_t pack_size = PackedSize;
+            constexpr index_t pack_size = 8;
+
+            static_assert(SrcScalarPerVector % pack_size == 0, "");

             using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
-            using src_v_t = typename vector_type_maker_t<SrcData, 1>::type;
+            using src_v_t = typename vector_type_maker_t<SrcData, 4>::type;

             static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
-                ck::tensor_operation::element_wise::PassThroughPack2{}(
+                ck::tensor_operation::element_wise::PassThroughPack8{}(
                     dst_tmp_vector.template AsType<dst_v_t>()(i),
                     src_tmp_vector.template AsType<src_v_t>()[i]);
             });
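Each static_for step now consumes one src_v_t (4 x pk_i4_t, i.e. eight nibbles) and emits one half8_t, and the new static_assert guarantees SrcScalarPerVector is a whole number of such packs. As I read the LO/HI masks and the >> 8 shift, the output lanes are drawn from nibbles in the order 0, 4, 1, 5, 2, 6, 3, 7, so the int4 weights are expected pre-shuffled accordingly. A host-side emulation (floats standing in for halves; not the device code):

#include <cstdint>
#include <cstdio>

// Emulates one PassThroughPack8 step: one 32-bit word = 8 int4 values,
// expanded to 8 values with the -8 zero point and the implied lane order.
void pass_through_pack8_ref(uint32_t x, float out[8])
{
    const int order[8] = {0, 4, 1, 5, 2, 6, 3, 7}; // nibble index per lane
    for (int k = 0; k < 8; ++k)
        out[k] = float(int((x >> (4 * order[k])) & 0xf) - 8);
}

int main()
{
    float h[8];
    pass_through_pack8_ref(0x99999999u, h); // every nibble decodes to +1
    for (float v : h)
        printf("%g ", v);
    printf("\n");
    return 0;
}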
include/ck/utility/amd_inline_asm.hpp

@@ -11,6 +11,21 @@
 namespace ck {

+inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
+{
+    half2_t d;
+    asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+    return d;
+}
+
+inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
+{
+    half2_t c;
+    asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+    return c;
+}
+
 // c0 += inner_product(a, b0)
 // c1 += inner_product(a, b1)
 __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
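v_pk_fma_f16 and v_pk_add_f16 operate on both fp16 lanes of a 32-bit VGPR at once; wrapping them in inline asm presumably pins the packed instructions in place rather than leaving instruction selection to the compiler. A usage sketch (hypothetical device function, not from the commit):

// Two fp16 multiply-adds issue as a single packed instruction per call.
__device__ half2_t scale_and_shift(half2_t x, half2_t scale, half2_t shift)
{
    return amd_assembly_pk_fma_f16(x, scale, shift); // per lane: x*scale + shift
}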
include/ck/utility/amd_xdlops.hpp

@@ -157,8 +157,8 @@ struct intrin_mfma_f32_16x16x16f16<16, 16>
     template <class FloatC>
     __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
     {
-        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
-            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+        // reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
+        //     reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
     }
 };
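For reference, the builtin being commented out performs one 16x16x16 matrix multiply-accumulate per wavefront, with fp16 inputs (a half4_t fragment per lane) and fp32 accumulation. A simplified host-side statement of its semantics (a sketch; real operands are distributed across 64 lanes):

#include <cstdio>

// C[16][16] += A[16][16] * B[16][16]; inputs would be fp16 on device.
void mfma_16x16x16_ref(const float A[16][16], const float B[16][16], float C[16][16])
{
    for (int m = 0; m < 16; ++m)
        for (int n = 0; n < 16; ++n)
            for (int k = 0; k < 16; ++k)
                C[m][n] += A[m][k] * B[k][n];
}

int main()
{
    static float A[16][16], B[16][16], C[16][16]; // zero-initialized
    for (int i = 0; i < 16; ++i)
    {
        A[i][i] = 1.0f; // identity
        B[i][i] = 2.0f; // 2 * identity
    }
    mfma_16x16x16_ref(A, B, C);
    printf("C[3][3] = %g (expect 2)\n", C[3][3]);
    return 0;
}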
include/ck/utility/data_type.hpp

@@ -1054,12 +1054,14 @@ using bf8x32_t = typename vector_type<bf8_t, 32>::type;
 using bf8x64_t = typename vector_type<bf8_t, 64>::type;

 // u8
 // i8
-using uint8x2_t  = typename vector_type<uint8_t, 2>::type;
-using uint8x4_t  = typename vector_type<uint8_t, 4>::type;
-using uint8x8_t  = typename vector_type<uint8_t, 8>::type;
-using uint8x16_t = typename vector_type<uint8_t, 16>::type;
-using uint8x32_t = typename vector_type<uint8_t, 32>::type;
-using uint8x64_t = typename vector_type<uint8_t, 64>::type;
+// using uint8x2_t  = typename vector_type<uint8_t, 2>::type;
+// using uint8x4_t  = typename vector_type<uint8_t, 4>::type;
+// using uint8x8_t  = typename vector_type<uint8_t, 8>::type;
+// using uint8x16_t = typename vector_type<uint8_t, 16>::type;
+// using uint8x32_t = typename vector_type<uint8_t, 32>::type;
+// using uint8x64_t = typename vector_type<uint8_t, 64>::type;
+
+using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
+using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;

 template <typename T>
 struct NumericLimits
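The new pk_i4x4_t alias carries four pk_i4_t values; since PassThroughPack8 bit_casts it to int, it must occupy exactly 32 bits, i.e. one byte (two nibbles) per pk_i4_t. A standalone model of that layout (assumed semantics, not CK's definitions):

#include <cstdint>

struct pk_i4_model_t   { uint8_t data; };     // two int4 values per byte
struct pk_i4x4_model_t { pk_i4_model_t x[4]; }; // eight int4 values

static_assert(sizeof(pk_i4x4_model_t) == sizeof(int32_t),
              "eight int4 values occupy exactly one 32-bit word");

int main() { return 0; }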