Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ca9b152d
Unverified
Commit
ca9b152d
authored
Oct 13, 2023
by
ltqin
Committed by
GitHub
Oct 13, 2023
Browse files
Merge pull request #988 from ROCmSoftwarePlatform/mha-train-develop-tinyfix
Two tiny updates
parents
3f4eae1d
aef8ea3c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
6 deletions
+75
-6
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp
...gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp
+29
-6
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+46
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp
View file @
ca9b152d
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include <cstring>
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
...
@@ -687,12 +688,34 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
...
@@ -687,12 +688,34 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
some_has_main_k_block_loop
|=
y
;
some_has_main_k_block_loop
|=
y
;
}
}
hipGetErrorString
(
hipStreamCaptureStatus
status
=
hipStreamCaptureStatusNone
;
hipMemcpyWithStream
(
arg
.
p_workspace_
,
arg
.
group_kernel_args_
.
data
(),
HIP_CHECK_ERROR
(
hipStreamIsCapturing
(
stream_config
.
stream_id_
,
&
status
));
arg
.
group_kernel_args_
.
size
()
*
sizeof
(
GroupKernelArg
),
hipMemcpyHostToDevice
,
if
(
status
==
hipStreamCaptureStatusActive
)
stream_config
.
stream_id_
));
{
size_t
copy_size
=
arg
.
group_kernel_args_
.
size
()
*
sizeof
(
GroupKernelArg
);
// ToDO: when to release this memory buffer?
char
*
persistent_ptr
=
new
char
[
copy_size
];
(
void
)
std
::
memcpy
(
persistent_ptr
,
arg
.
group_kernel_args_
.
data
(),
copy_size
);
HIP_CHECK_ERROR
(
hipMemcpyAsync
(
arg
.
p_workspace_
,
persistent_ptr
,
copy_size
,
hipMemcpyHostToDevice
,
stream_config
.
stream_id_
));
}
else
{
HIP_CHECK_ERROR
(
hipMemcpyAsync
(
arg
.
p_workspace_
,
arg
.
group_kernel_args_
.
data
(),
arg
.
group_kernel_args_
.
size
()
*
sizeof
(
GroupKernelArg
),
hipMemcpyHostToDevice
,
stream_config
.
stream_id_
));
}
float
ave_time
=
0
;
float
ave_time
=
0
;
...
...
include/ck/utility/type_convert.hpp
View file @
ca9b152d
...
@@ -31,6 +31,51 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
...
@@ -31,6 +31,51 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
return
u
.
fp32
;
return
u
.
fp32
;
}
}
#ifdef USE_RTN_BF16_CONVERT
// Convert fp32 to bf16 with RTN if higher precision is needed
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
float
>
(
float
x
)
{
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
x
};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool
flag0
=
~
u
.
int32
&
0x7f800000
;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool
flag1
=
!
flag0
&&
(
u
.
int32
&
0xffff
);
u
.
int32
+=
flag0
?
0x7fff
+
((
u
.
int32
>>
16
)
&
1
)
:
0
;
// Round to nearest, round to even
u
.
int32
|=
flag1
?
0x10000
:
0x0
;
// Preserve signaling NaN
return
uint16_t
(
u
.
int32
>>
16
);
}
#else
// convert fp32 to bfp16
// convert fp32 to bfp16
template
<
>
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
float
>
(
float
x
)
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
float
>
(
float
x
)
...
@@ -43,6 +88,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
...
@@ -43,6 +88,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
return
uint16_t
(
u
.
int32
>>
16
);
return
uint16_t
(
u
.
int32
>>
16
);
}
}
#endif
// convert bfp16 to fp16 via fp32
// convert bfp16 to fp16 via fp32
template
<
>
template
<
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment