gaoqiong / flash-attention / Commits

Commit 1274ec3e, authored Jan 14, 2024 by Tri Dao
parent 10dad612

Move dropout to a separate file (dropout.h)

Showing 4 changed files with 112 additions and 95 deletions:

    csrc/flash_attn/src/dropout.h            +91    -0
    csrc/flash_attn/src/flash_bwd_kernel.h    +9   -10
    csrc/flash_attn/src/flash_fwd_kernel.h   +12   -15
    csrc/flash_attn/src/softmax.h             +0   -70
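In short (a before/after sketch assembled from the lines changed below, shown here only for orientation, not part of the commit itself): the free function flash::apply_dropout that lived in softmax.h becomes a flash::Dropout struct in the new dropout.h. The RNG seed, the Philox offset, and the 8-bit dropout threshold are captured once in the constructor, so each call site now passes only block coordinates.

    // Before: free function in softmax.h, RNG state passed on every call
    flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
        scores, params.p_dropout_in_uint8_t, seed, offset,
        block_row_idx, block_col_idx, AtomLayoutMS
    );

    // After: flash::Dropout object built once per thread in dropout.h
    flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
                           bidb, bidh, tidx, params.h);
    dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
        scores, block_row_idx, block_col_idx, AtomLayoutMS
    );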
csrc/flash_attn/src/dropout.h (new file, 0 → 100644)

#pragma once

#include "philox.cuh"
#include "utils.h"

namespace flash {

struct Dropout {

    const unsigned long long seed, offset;
    const uint8_t p_dropout_in_uint8_t;

    inline __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
                              const uint8_t p_dropout_in_uint8_t,
                              const int bid, const int hid, const int tid, const int nheads)
            : seed(seed)
            , offset(offset + (bid * nheads + hid) * 32 + tid % 32)
            , p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
    }

    template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
    inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
                                         int block_row_start, int block_col_start, int block_row_stride) {
        // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
        Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
        // tensor has shape (8, MMA_M, MMA_N / 2)
        using T = typename Engine::value_type;
        auto encode_dropout = [](bool keep, T val) {
            return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
        };
        static_assert(decltype(size<2>(tensor))::value % 2 == 0);
        const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
        const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
        // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
        #pragma unroll
        for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
            uint2 rowcol = make_uint2(block_row_start, block_col_start);
            #pragma unroll
            for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
                // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
                uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
                // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
                uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
                // Special implementation for 16-bit types: we duplicate the threshold to the
                // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
                // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
                // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
                // the random value is less than the threshold.
                // We then do a bit-wise AND between the mask and the original value (in 32-bit).
                // We're exploiting the fact that floating point comparison is equivalent to integer
                // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
                if (!encode_dropout_in_sign_bit
                    && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
                    uint16_t rnd_16[16];
                    #pragma unroll
                    for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
                    uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
                    #pragma unroll
                    for (int j = 0; j < 2; j++) {
                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                        // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                        #pragma unroll
                        for (int i = 0; i < 4; i++) {
                            uint32_t mask;
                            asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
                            tensor_uint32(i) &= mask;
                        }
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                    }
                } else {
                    #pragma unroll
                    for (int j = 0; j < 2; j++) {
                        #pragma unroll
                        for (int i = 0; i < 8; i++) {
                            tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
                        }
                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                    }
                }
                // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                // //     printf("n = %d, ph  Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
                // // }
            }
        }
    }

};

}  // namespace flash
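For a sense of the integer arithmetic the struct encapsulates, here is a minimal, self-contained CPU sketch (not part of the commit; all values are made up) of the per-thread Philox offset computed in the constructor, the duplication of the 8-bit threshold into both halves of a 32-bit word for the f16x2 comparison path, and the scalar keep test rnd <= p_dropout_in_uint8_t:

    // Hypothetical illustration only; names mirror dropout.h but nothing here is CUDA.
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Constructor: offset(offset + (bid * nheads + hid) * 32 + tid % 32)
        const unsigned long long base_offset = 0;   // stands in for params.rng_state[1]
        const int nheads = 16;                      // hypothetical params.h
        const int bid = 2, hid = 3;                 // hypothetical batch and head indices
        const int tids[] = {0, 1, 31, 32};          // lanes of the first two warps
        for (int tid : tids) {
            unsigned long long off = base_offset + (bid * nheads + hid) * 32ULL + tid % 32;
            // lanes 0..31 get consecutive slots; tid = 32 maps back to the tid = 0 slot
            std::printf("tid=%2d -> philox offset %llu\n", tid, off);
        }

        // apply_dropout: threshold byte duplicated into both 16-bit halves of a 32-bit word,
        // p_dropout_8bit_in_uint32_t = (uint32_t(t) << 16) | uint32_t(t)
        const uint8_t p_dropout_in_uint8_t = 200;   // hypothetical 8-bit threshold
        const uint16_t t16 = uint16_t(p_dropout_in_uint8_t);
        const uint32_t t32 = (uint32_t(t16) << 16) | uint32_t(t16);
        std::printf("threshold2 = 0x%08x\n", t32);

        // Scalar fallback path: an element is kept when its random byte satisfies rnd <= threshold.
        const uint8_t rnd[4] = {13, 199, 200, 255}; // hypothetical Philox output bytes
        for (uint8_t r : rnd) {
            std::printf("rnd=%3u keep=%d\n", unsigned(r), int(r <= p_dropout_in_uint8_t));
        }
        return 0;
    }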
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -14,6 +14,7 @@
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
+#include "dropout.h"
 #include "alibi.h"

@@ -796,8 +797,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }

-    auto seed = params.rng_state[0];
-    auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;
+    flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
+                           bidb, bidh, tidx, params.h);

     clear(acc_dv);
     clear(acc_dk);

@@ -886,9 +887,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
             static_assert(MMA_N_SdP % 2 == 0);
             int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
-            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                scores, params.p_dropout_in_uint8_t, seed, offset,
-                block_row_idx, block_col_idx, AtomLayoutMS
-            );
+            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+                scores, block_row_idx, block_col_idx, AtomLayoutMS
+            );
         }
         // Convert scores from fp32 to fp16/bf16

@@ -1395,8 +1395,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     #pragma unroll
     for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<0>(taccScS_row(mi))); }

-    auto seed = params.rng_state[0];
-    auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;
+    flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
+                           bidb, bidh, tidx, params.h);

     clear(acc_dq);

@@ -1445,9 +1445,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
             // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
             static_assert(MMA_N_SdP % 2 == 0);
             int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
-            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                scores, params.p_dropout_in_uint8_t, seed, offset,
-                block_row_idx, block_col_idx, AtomLayoutMS
-            );
+            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+                scores, block_row_idx, block_col_idx, AtomLayoutMS
+            );
         }
         // Convert scores from fp32 to fp16/bf16
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -14,6 +14,7 @@
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
+#include "dropout.h"
 #include "alibi.h"

@@ -75,15 +76,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kNWarps = Kernel_traits::kNWarps;
     constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;

-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    unsigned long long seed = std::get<0>(seeds);
-    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
+    auto seed_offset = at::cuda::philox::unpack(params.philox_args);
+    flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
+                           bidb, bidh, tidx, params.h);

     // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
     // exit early and no one saves the rng states.
     if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
-        params.rng_state[0] = seed;
-        params.rng_state[1] = std::get<1>(seeds);
+        params.rng_state[0] = std::get<0>(seed_offset);
+        params.rng_state[1] = std::get<1>(seed_offset);
     }

     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);

@@ -404,16 +405,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         if (Return_softmax) {
             Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
             Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
-            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                acc_s_f16_drop, params.p_dropout_in_uint8_t, seed, offset,
-                block_row_idx, block_col_idx, kNWarps
-            );
+            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+                acc_s_f16_drop, block_row_idx, block_col_idx, kNWarps
+            );
             cute::copy(acc_s_f16, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
-            flash::apply_dropout(rP, params.p_dropout_in_uint8_t, seed, offset,
-                                 block_row_idx, block_col_idx, kNWarps);
+            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
         }
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)

@@ -489,16 +488,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         if (Return_softmax) {
             Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
             Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
-            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                acc_s_f16_drop, params.p_dropout_in_uint8_t, seed, offset,
-                block_row_idx, block_col_idx, kNWarps
-            );
+            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+                acc_s_f16_drop, block_row_idx, block_col_idx, kNWarps
+            );
             cute::copy(acc_s_f16, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
-            flash::apply_dropout(rP, params.p_dropout_in_uint8_t, seed, offset,
-                                 block_row_idx, block_col_idx, kNWarps);
+            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
        }
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
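A note on how the two passes stay in sync, as far as the hunks above show (philox.cuh itself is not part of this commit): only the forward kernel unpacks the Philox arguments, and it saves the (seed, offset) pair into params.rng_state before any early exit; the backward kernels rebuild their flash::Dropout from params.rng_state[0] and params.rng_state[1]. Because Philox is a counter-based generator, feeding it the same seed, offset, and per-(row, col) counter reproduces the same bytes, so the mask is regenerated rather than stored. A toy, host-side sketch of that idea (the mixing function below is a stand-in, not the real Philox):

    // Hypothetical illustration only.
    #include <cstdint>
    #include <cstdio>

    // Stand-in "counter-based RNG": a pure function of (seed, counter, offset).
    static uint64_t toy_philox(uint64_t seed, uint64_t counter, uint64_t offset) {
        uint64_t x = seed ^ (counter * 0x9E3779B97F4A7C15ULL) ^ (offset * 0xBF58476D1CE4E5B9ULL);
        x ^= x >> 30; x *= 0xBF58476D1CE4E5B9ULL;
        x ^= x >> 27; x *= 0x94D049BB133111EBULL;
        return x ^ (x >> 31);
    }

    int main() {
        uint64_t rng_state[2];
        // "Forward": unpack (seed, offset) once and save it, as the diff does with params.rng_state.
        const uint64_t seed = 0x1234, offset = 42;   // stand-ins for at::cuda::philox::unpack(...)
        rng_state[0] = seed;
        rng_state[1] = offset;
        const uint64_t counter = (7ULL << 32) | 3ULL;  // mirrors the packed (row, col) counter
        uint64_t fwd = toy_philox(seed, counter, offset);
        // "Backward": rebuild the generator state from rng_state and walk the same counter.
        uint64_t bwd = toy_philox(rng_state[0], counter, rng_state[1]);
        std::printf("forward 0x%016llx  backward 0x%016llx  match=%d\n",
                    (unsigned long long)fwd, (unsigned long long)bwd, int(fwd == bwd));
        return 0;
    }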
csrc/flash_attn/src/softmax.h

@@ -212,74 +212,4 @@ inline __device__ void apply_mask_causal_w_idx(
     }
 }

-template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
-inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_, uint8_t p_dropout_in_uint8_t,
-                                     unsigned long long seed, unsigned long long offset,
-                                     int block_row_start, int block_col_start, int block_row_stride) {
-    // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
-    Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
-    // tensor has shape (8, MMA_M, MMA_N / 2)
-    using T = typename Engine::value_type;
-    auto encode_dropout = [](bool keep, T val) {
-        return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
-    };
-    static_assert(decltype(size<2>(tensor))::value % 2 == 0);
-    const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
-    const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
-    // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
-    #pragma unroll
-    for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
-        uint2 rowcol = make_uint2(block_row_start, block_col_start);
-        #pragma unroll
-        for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
-            // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
-            uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
-            // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
-            uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
-            // Special implementation for 16-bit types: we duplicate the threshold to the
-            // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
-            // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
-            // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
-            // the random value is less than the threshold.
-            // We then do a bit-wise AND between the mask and the original value (in 32-bit).
-            // We're exploiting the fact that floating point comparison is equivalent to integer
-            // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
-            if (!encode_dropout_in_sign_bit
-                && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
-                uint16_t rnd_16[16];
-                #pragma unroll
-                for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
-                uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
-                #pragma unroll
-                for (int j = 0; j < 2; j++) {
-                    Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
-                    // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
-                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
-                    #pragma unroll
-                    for (int i = 0; i < 4; i++) {
-                        uint32_t mask;
-                        asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
-                        tensor_uint32(i) &= mask;
-                    }
-                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
-                }
-            } else {
-                #pragma unroll
-                for (int j = 0; j < 2; j++) {
-                    #pragma unroll
-                    for (int i = 0; i < 8; i++) {
-                        tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
-                    }
-                    Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
-                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
-                }
-            }
-            // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-            // //     printf("n = %d, ph  Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
-            // // }
-        }
-    }
-}
-
 }  // namespace flash