Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d3721152
"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "ea9685a93887989d67fb3b6f48ca338b3d2e5551"
Commit
d3721152
authored
Mar 16, 2023
by
wunhuang
Browse files
Change j matrix type from unsigned short to int
parent
55057f09
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
46 additions
and
44 deletions
+46
-44
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
+5
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
+5
-4
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
+2
-2
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+32
-32
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
d3721152
...
...
@@ -95,14 +95,15 @@ __global__ void
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
unsigned
short
*
z_matrix_ptr
=
(
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
);
//
unsigned short* z_matrix_ptr =
//
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
//
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_lse_grid_
+
lse_batch_offset
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
d3721152
...
...
@@ -95,14 +95,15 @@ __global__ void
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
unsigned
short
*
z_matrix_ptr
=
(
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
);
//
unsigned short* z_matrix_ptr =
//
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
//
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_lse_grid_
+
lse_batch_offset
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
View file @
d3721152
...
...
@@ -1236,7 +1236,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
YGradGridDesc_O0_M_O1
>
__device__
static
void
Run
(
const
DataType
*
__restrict__
p_q_grid
,
const
DataType
*
__restrict__
p_k_grid
,
unsigned
shor
t
*
__restrict__
p_z_grid
,
in
t
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_v_grid
,
const
DataType
*
__restrict__
p_y_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
...
...
@@ -1552,7 +1552,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushor
t
,
in
t
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
tensor_operation
::
element_wise
::
PassThrough
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
View file @
d3721152
...
...
@@ -1146,7 +1146,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
YGradGridDesc_M0_O_M1
>
__device__
static
void
Run
(
const
DataType
*
__restrict__
p_q_grid
,
const
DataType
*
__restrict__
p_k_grid
,
unsigned
shor
t
*
__restrict__
p_z_grid
,
in
t
*
__restrict__
p_z_grid
,
const
DataType
*
__restrict__
p_v_grid
,
const
DataType
*
__restrict__
p_y_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
...
...
@@ -1484,7 +1484,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushor
t
,
in
t
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
tensor_operation
::
element_wise
::
PassThrough
,
...
...
include/ck/utility/data_type.hpp
View file @
d3721152
...
...
@@ -974,38 +974,38 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
uint32_t
int32
;
}
u
=
{
x
};
if
(
~
u
.
int32
&
0x7f800000
)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u
.
int32
+=
0x7fff
+
((
u
.
int32
>>
16
)
&
1
);
// Round to nearest, round to even
}
else
if
(
u
.
int32
&
0xffff
)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bloat16's mantissa bits are all 0.
u
.
int32
|=
0x10000
;
// Preserve signaling NaN
}
//
if(~u.int32 & 0x7f800000)
//
{
//
// When the exponent bits are not all 1s, then the value is zero, normal,
//
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
//
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
//
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
//
// least significant bits of the float mantissa are greater than 0x8000,
//
// or if they are equal to 0x8000 and the least significant bit of the
//
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
//
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
//
// has the value 0x7f, then incrementing it causes it to become 0x00 and
//
// the exponent is incremented by one, which is the next higher FP value
//
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
//
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
//
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
//
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
//
// incrementing it causes it to become an exponent of 0xFF and a mantissa
//
// of 0x00, which is Inf, the next higher value to the unrounded value.
//
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
//
}
//
else if(u.int32 & 0xffff)
//
{
//
// When all of the exponent bits are 1, the value is Inf or NaN.
//
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
//
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
//
// bit being 1. Signaling NaN is indicated by the most significant
//
// mantissa bit being 0 but some other bit(s) being 1. If any of the
//
// lower 16 bits of the mantissa are 1, we set the least significant bit
//
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
//
// the bloat16's mantissa bits are all 0.
//
u.int32 |= 0x10000; // Preserve signaling NaN
//
}
return
uint16_t
(
u
.
int32
>>
16
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment