Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5e319fb8
"scripts/vscode:/vscode.git/clone" did not exist on "4075677621f3be941f205cac669d37b8db3a8851"
Commit
5e319fb8
authored
May 17, 2023
by
guangzlu
Browse files
modified dropout in fwd
parent
762cbddf
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
186 additions
and
9 deletions
+186
-9
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+111
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
...vice_batched_multihead_attention_forward_xdl_cshuffle.hpp
+18
-6
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+57
-3
No files found.
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
5e319fb8
...
...
@@ -50,6 +50,41 @@ struct BlockwiseDropout
});
}
template
<
typename
CThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
8
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_8x16
((
tmp
+
i
*
8
),
element_global_1d_id
);
}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
tmp_index
=
tmp_index
+
1
;
});
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
ZThreadBuffer
&
z_thread_buf
)
...
...
@@ -86,6 +121,82 @@ struct BlockwiseDropout
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
8
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_8x16
((
tmp
+
i
*
8
),
element_global_1d_id
+
philox_calls
*
8
);
}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
z_thread_buf
(
offset
)
=
tmp
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
});
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
,
typename
N0
,
typename
Offset
>
__host__
__device__
void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
/
N0
{}.
value
;
int
philox_calls
=
tmp_size
/
8
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_8x16
((
tmp
+
i
*
8
),
element_global_1d_id
);
}
block_sync_lds
();
constexpr
auto
iOffset
=
Number
<
tmp_size
>
{}
*
Offset
{};
static_for
<
0
,
tmp_size
,
1
>
{}([
&
](
auto
i
)
{
in_thread_buf
(
i
+
iOffset
)
=
execute_dropout
(
tmp
[
i
.
value
]
<=
p_dropout_16bits
,
in_thread_buf
(
i
+
iOffset
));
z_thread_buf
(
i
)
=
tmp
[
i
.
value
];
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
5e319fb8
...
...
@@ -79,7 +79,9 @@ __global__ void
const
ushort
p_dropout_in_16bits
,
const
GemmAccDataType
p_dropout_rescale
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
,
const
index_t
MRaw
,
const
index_t
NRaw
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -112,8 +114,8 @@ __global__ void
p_b_grid
+
b_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
p_lse_grid
==
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -131,6 +133,9 @@ __global__ void
p_dropout_in_16bits
,
p_dropout_rescale
,
ph
,
g_idx
,
MRaw
,
NRaw
,
i
);
}
}
...
...
@@ -141,8 +146,8 @@ __global__ void
p_b_grid
+
b_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
p_lse_grid
==
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -160,6 +165,9 @@ __global__ void
p_dropout_in_16bits
,
p_dropout_rescale
,
ph
,
g_idx
,
MRaw
,
NRaw
,
0
);
}
#else
...
...
@@ -644,6 +652,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
{
is_lse_storing_
=
false
;
}
// std::cout << "batch_count_: " << batch_count_ << std::endl;
}
void
Print
()
const
...
...
@@ -803,7 +813,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
arg
.
p_dropout_in_16bits_
,
arg
.
p_dropout_rescale_
,
arg
.
seed_
,
arg
.
offset_
);
arg
.
offset_
,
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
],
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
5e319fb8
...
...
@@ -447,6 +447,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const
ushort
p_dropout_in_16bits
,
FloatGemmAcc
p_dropout_rescale
,
ck
::
philox
&
ph
,
const
index_t
g_idx
,
const
index_t
MRaw
,
const
index_t
NRaw
,
const
index_t
block_idx_m
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
...
@@ -986,6 +989,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
// if(get_thread_global_1d_id()==0){
// printf("m_global is %d \n", m_global);
// printf("n_global is %d \n", n_global);
//}
if
(
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
n_global
))
{
acc_thread_buf
(
i
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
...
...
@@ -1012,6 +1021,51 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
if
constexpr
(
IsDropout
)
// dropout
{
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
N2
=
c_block_lengths
[
I5
];
constexpr
auto
N3
=
c_block_lengths
[
I6
];
constexpr
auto
N4
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id
=
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
// if(get_thread_global_1d_id()==1){
// printf("at 1 m_global is %d \n", m_global);
// printf("at 1 n_global is %d \n", n_global);
// printf("at 1 global_elem_id is %d \n", global_elem_id);
// }
// save z to global
if
(
p_z_grid
)
...
...
@@ -1022,7 +1076,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
false
,
decltype
(
n0
),
decltype
(
i
)>(
acc_thread_buf
,
ph
,
z_tenor_buffer
);
acc_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
...
...
@@ -1046,7 +1100,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// ignore = z_grid_buf;
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
false
>(
acc_thread_buf
,
ph
);
acc_thread_buf
,
ph
,
global_elem_id
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment