gaoqiong / composable_kernel · Commit da0bb989

Commit da0bb989, authored Mar 07, 2023 by ltqin
Parent: 36ca02f3

    add bfloat16 to xdlops

Showing 5 changed files with 97 additions and 2 deletions (+97 −2)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt2.cpp   +1 −1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                  +6 −0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                      +22 −1
include/ck/utility/amd_xdlops.hpp                                                         +13 −0
include/ck/utility/data_type.hpp                                                          +55 −0
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt2.cpp

@@ -62,7 +62,7 @@ using QKVElementOp = PassThrough;
 using YElementOp = PassThrough;

 using DataType        = F16;
-using GemmDataType    = F16;
+using GemmDataType    = BF16;
 using AccDataType     = F32;
 using ShuffleDataType = F32;
 using LSEDataType     = F32;
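The only change in this example is the GemmDataType alias: the attention GEMM operands are now stored and multiplied as bf16, while accumulation stays in fp32 (AccDataType = F32 is untouched). A minimal standalone sketch of that numeric contract, assuming bf16 is held as a raw 16-bit pattern the way this commit's bfloat16_t (an int16_t alias) holds it:

// Scalar reference for "bf16 inputs, fp32 accumulation" (illustrative only).
#include <cstdint>
#include <cstdio>
#include <cstring>

using bf16_raw = uint16_t; // raw bf16 bits, analogous to CK's bfloat16_t

static float bf16_to_f32(bf16_raw x)
{
    uint32_t bits = uint32_t(x) << 16; // bf16 is the high half of an fp32
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// One output element: widen each bf16 operand to fp32, accumulate in fp32.
static float dot_bf16_acc_f32(const bf16_raw* a, const bf16_raw* b, int K)
{
    float acc = 0.0f;
    for(int k = 0; k < K; ++k)
        acc += bf16_to_f32(a[k]) * bf16_to_f32(b[k]);
    return acc;
}

int main()
{
    bf16_raw a[2] = {0x3F80, 0x4000}; // 1.0f, 2.0f in bf16
    bf16_raw b[2] = {0x4040, 0x3F00}; // 3.0f, 0.5f in bf16
    printf("%f\n", dot_bf16_acc_f32(a, b, 2)); // 1*3 + 2*0.5 = 4.0
    return 0;
}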
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -40,6 +40,12 @@ struct PassThrough
         y = x;
     }

+    template <>
+    __host__ __device__ void operator()<bfloat16_t, bfloat16_t>(bfloat16_t& y, const bfloat16_t& x) const
+    {
+        y = x;
+    }
+
     template <>
     __host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
     {
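PassThrough's operator() is a member function template, and this hunk adds the explicit specialization for the new bfloat16_t so the no-op epilogue can be instantiated for bf16 tensors. A portable standalone sketch of the same pattern (CK specializes in-class, which hipcc/clang accepts; the sketch specializes at namespace scope so it also builds with gcc):

#include <cstdint>

using bfloat16_t = int16_t; // raw bf16 bit pattern, as in this commit's data_type.hpp

struct PassThrough
{
    // Primary member template: one explicit specialization per supported type pair.
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const;
};

// Explicit specialization for bf16: identity, just a 16-bit copy.
template <>
void PassThrough::operator()<bfloat16_t, bfloat16_t>(bfloat16_t& y, const bfloat16_t& x) const
{
    y = x;
}

int main()
{
    PassThrough op;
    bfloat16_t x = 0x4049, y = 0; // 0x4049 is bf16 for ~3.1406
    op(y, x);
    return y == x ? 0 : 1;
}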
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -524,6 +524,27 @@ struct MfmaSelector
 #endif
     }

+    template <>
+    static constexpr auto GetMfma<bfloat16_t, 32, 32>()
+    {
+#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
+        return MfmaInstr::mfma_f32_32x32x8bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_32x32x4bf16;
+#endif
+    }
+
+    template <>
+    static constexpr auto GetMfma<bfloat16_t, 16, 16>()
+    {
+#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
+        return MfmaInstr::mfma_f32_16x16x16bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_16x16x8bf16;
+#endif
+    }
+
     template <>
     static constexpr auto GetMfma<int8_t, 32, 32>()
     {
@@ -735,7 +756,7 @@ struct XdlopsGemm
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
         static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
                           is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
-                          is_same<base_type, int8_t>::value,
+                          is_same<base_type, bfloat16_t>::value || is_same<base_type, int8_t>::value,
                       "base base_type must be double, float, half, bfloat16, and int8_t!");
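The two GetMfma specializations above route bf16 GEMMs to an MFMA instruction at compile time: when CK_USE_AMD_MFMA_BF16_1K_OP is defined (gfx90a-class hardware) they pick the "_1k" forms, which double the K depth per instruction (32x32x8 vs 32x32x4, 16x16x16 vs 16x16x8); otherwise they fall back to the original gfx908 bf16 MFMAs. A minimal standalone analogue of that dispatch, with illustrative names rather than the CK API:

#include <cstdio>

// Illustrative stand-ins for CK's MfmaInstr enum and bfloat16_t key type.
enum class MfmaInstr
{
    mfma_f32_32x32x8bf16_1k,
    mfma_f32_32x32x4bf16,
    mfma_f32_16x16x16bf16_1k,
    mfma_f32_16x16x8bf16
};

struct bfloat16_tag
{
};

// Primary template: specialized per (element type, MPerWave, NPerWave).
template <typename T, int MPerWave, int NPerWave>
constexpr MfmaInstr get_mfma();

template <>
constexpr MfmaInstr get_mfma<bfloat16_tag, 32, 32>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
    return MfmaInstr::mfma_f32_32x32x8bf16_1k; // gfx90a+: K=8 per instruction
#else
    return MfmaInstr::mfma_f32_32x32x4bf16; // gfx908 fallback: K=4
#endif
}

template <>
constexpr MfmaInstr get_mfma<bfloat16_tag, 16, 16>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
    return MfmaInstr::mfma_f32_16x16x16bf16_1k; // gfx90a+: K=16
#else
    return MfmaInstr::mfma_f32_16x16x8bf16; // gfx908 fallback: K=8
#endif
}

int main()
{
    constexpr MfmaInstr instr = get_mfma<bfloat16_tag, 32, 32>();
    printf("selected MFMA id: %d\n", static_cast<int>(instr));
    return 0;
}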
include/ck/utility/amd_xdlops.hpp

@@ -215,6 +215,13 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
             reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
     }

+    template <class FloatC>
+    __device__ static void Run(const bfloat16x4_t& reg_a, const bfloat16x4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+    }
 };

 template <index_t MPerWave, index_t NPerWave>

@@ -243,6 +250,12 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32>
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
             reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
     }

+    template <class FloatC>
+    __device__ static void Run(const bfloat16x2_t& reg_a, const bfloat16x2_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+    }
 };

 template <index_t MPerWave, index_t NPerWave>
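The new Run overloads accept bfloat16x4_t / bfloat16x2_t fragments, matching the operand width of each builtin: the 32x32x8 "_1k" MFMA consumes 4 bf16 values per lane, the 32x32x4 form consumes 2. A hedged HIP sketch of how one such call reaches hardware (assumes a gfx90a target; the buffer names and launch setup are illustrative, and each 64-lane wavefront must already hold its 32x32 tile fragments in the layout the instruction expects):

#include <hip/hip_runtime.h>

// ROCm clang models the bf16 MFMA operands as short vectors (raw bf16 bits).
typedef short bf16x4 __attribute__((ext_vector_type(4)));  // 4 bf16 per lane
typedef float f32x16 __attribute__((ext_vector_type(16))); // 16 fp32 accumulators per lane

__global__ void mfma_bf16_demo(const bf16x4* a, const bf16x4* b, f32x16* c)
{
#if defined(__gfx90a__)
    int lane = threadIdx.x; // one 64-lane wavefront executes one 32x32x8 MFMA
    c[lane]  = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a[lane], b[lane], c[lane], 0, 0, 0);
#endif
}

The three trailing zeros are the cbsz/abid/blgp modifiers, which the CK wrappers above also pass through unchanged.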
include/ck/utility/data_type.hpp

@@ -13,6 +13,8 @@ using half_t = _Float16;
 using int4_t = _BitInt(4);
 #endif

+using bfloat16_t = int16_t;
+
 // vector_type
 template <typename T, index_t N>
 struct vector_type;

@@ -119,6 +121,13 @@ struct scalar_type<bhalf_t>
     static constexpr index_t vector_size = 1;
 };

+template <>
+struct scalar_type<bfloat16_t>
+{
+    using type                           = bfloat16_t;
+    static constexpr index_t vector_size = 1;
+};
+
 template <>
 struct scalar_type<int32_t>
 {

@@ -926,6 +935,13 @@ using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
 using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
 using bhalf64_t = typename vector_type<bhalf_t, 64>::type;

+// bfloat16_t
+using bfloat16x2_t  = typename vector_type<bfloat16_t, 2>::type;
+using bfloat16x4_t  = typename vector_type<bfloat16_t, 4>::type;
+using bfloat16x8_t  = typename vector_type<bfloat16_t, 8>::type;
+using bfloat16x16_t = typename vector_type<bfloat16_t, 16>::type;
+using bfloat16x32_t = typename vector_type<bfloat16_t, 32>::type;
+using bfloat16x64_t = typename vector_type<bfloat16_t, 64>::type;
+
 // i32
 using int32x2_t = typename vector_type<int32_t, 2>::type;
 using int32x4_t = typename vector_type<int32_t, 4>::type;

@@ -1023,6 +1039,45 @@ inline __host__ __device__ bhalf_t type_convert<bhalf_t, half_t>(half_t x)
     return uint16_t(u.int32 >> 16);
 }

+// convert bfp16 to fp32
+template <>
+inline __host__ __device__ constexpr float type_convert<float, bfloat16_t>(bfloat16_t x)
+{
+    union
+    {
+        uint32_t int32;
+        float fp32;
+    } u = {uint32_t(x) << 16};
+
+    return u.fp32;
+}
+
+// convert fp32 to bfp16
+template <>
+inline __host__ __device__ constexpr bfloat16_t type_convert<bfloat16_t, float>(float x)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {x};
+
+    return uint16_t(u.int32 >> 16);
+}
+
+// convert fp16 to bf16
+template <>
+inline __host__ __device__ bfloat16_t type_convert<bfloat16_t, half_t>(half_t x)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {static_cast<float>(x)};
+
+    return uint16_t(u.int32 >> 16);
+}
+
 template <>
 inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, half2_t>(half2_t x)
 {
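These conversions treat bfloat16_t as the high 16 bits of an IEEE-754 fp32: widening shifts left by 16, narrowing simply drops the low 16 bits. A standalone host-side demo of the same bit manipulation (using std::memcpy instead of the union punning in the device code above):

#include <cstdint>
#include <cstdio>
#include <cstring>

using bf16_raw = uint16_t; // raw bf16 bit pattern, as this commit's bfloat16_t holds it

static float bf16_to_float(bf16_raw x)
{
    uint32_t bits = uint32_t(x) << 16; // bf16 is the top 16 bits of an fp32
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

static bf16_raw float_to_bf16(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return bf16_raw(bits >> 16); // truncate the low mantissa bits, as the commit does
}

int main()
{
    float x = 3.14159f;
    bf16_raw b = float_to_bf16(x);
    printf("%f -> 0x%04x -> %f\n", x, b, bf16_to_float(b)); // 3.141590 -> 0x4049 -> 3.140625
    return 0;
}

Note that the narrowing path truncates rather than rounds; a round-to-nearest-even variant would add 0x7FFF plus the current LSB of the result before shifting (bits += 0x7FFF + ((bits >> 16) & 1)).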