gaoqiong / flash-attention / Commits

Commit 5db33051, authored Dec 12, 2022 by Tri Dao

[LayerNorm] Support taking subset of input or subset of output

parent ae137ed1

Showing 6 changed files with 647 additions and 172 deletions (+647 −172)
csrc/layer_norm/ln.h                    +3   −0
csrc/layer_norm/ln_api.cpp              +71  −15
csrc/layer_norm/ln_bwd_kernels.cuh      +138 −103
csrc/layer_norm/ln_fwd_kernels.cuh      +49  −33
flash_attn/ops/layer_norm.py            +119 −4
tests/ops/test_dropout_layer_norm.py    +267 −17
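Taken together, the six diffs below make one change (this summary is inferred from the code, since the page preserves no commit body): dropout_add_ln_fwd/bwd gain two optional int32 index tensors, x0_subset and z_subset (exposed in Python as x0_subset and out_subset). When present, the forward kernel reads x0 only for rows whose x0_subset entry is positive and writes the normalized output z only for rows whose z_subset entry is positive, with z allocated as z_numrows x hidden_size; the per-row rowscale tensor is replaced by a single rowscale_const in this path. The residual stream x keeps one row per token throughout.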
csrc/layer_norm/ln.h

@@ -66,11 +66,14 @@ struct ParamsBase {
     void *gamma;
     void *rowscale;
     void *colscale;
+    void *x0_subset;
+    void *z_subset;

     float inverse_cols;

     float dropout_keep_p;
     float dropout_scale;
+    float rowscale_const;

     // Multi-CTA workspace in gmem.
     void *workspace;
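From the way both kernels test row_x0 > 0 / row_z > 0 and subtract 1 before computing offsets, the subset convention appears to be: indices are 1-based positions into the packed subset tensor, and 0 marks a row that is skipped. A minimal PyTorch sketch of that convention (an illustrative helper, not code from this commit; the name gather_rows is hypothetical):

    import torch

    def gather_rows(x, subset, num_out_rows):
        # subset[row] == 0  -> row is not in the subset (skipped)
        # subset[row] == k  -> row is stored at packed position k - 1
        out = torch.zeros(num_out_rows, x.shape[1], dtype=x.dtype, device=x.device)
        for row, idx in enumerate(subset.tolist()):
            if idx > 0:
                out[idx - 1] = x[row]
        return out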
csrc/layer_norm/ln_api.cpp

@@ -84,9 +84,13 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
                                            const at::Tensor &gamma,   // hidden_size
                                            const at::Tensor &beta,    // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,      // BxS
-                                           c10::optional<const at::Tensor> &colscale_,      // BxS
+                                           c10::optional<const at::Tensor> &colscale_,      // hidden_size
+                                           c10::optional<const at::Tensor> &x0_subset_,     // BxS
+                                           c10::optional<const at::Tensor> &z_subset_,      // BxS
                                            const float dropout_p,
                                            const float epsilon,
+                                           const float rowscale_const,
+                                           const int64_t z_numrows,
                                            c10::optional<at::Generator> gen_,
                                            bool residual_in_fp32
 ) {

@@ -99,14 +103,19 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     auto ctype = torch::kFloat32;
     auto mtype = torch::kUInt8;

-    TORCH_CHECK(beta.scalar_type() == wtype);
+    TORCH_CHECK(beta.dtype() == wtype);

     TORCH_CHECK(x0.is_cuda())
     TORCH_CHECK(gamma.is_cuda())
     TORCH_CHECK(beta.is_cuda())

     TORCH_CHECK(x0.is_contiguous());
-    auto sizes = x0.sizes();
+    // c10::IntArrayRef does not own the storage, so we need to construct a vector.
+    // Otherwise just constructing IntArrayRef({blah}) will cause unintialized memory because
+    // blah is then deallocated.
+    std::vector<int64_t> sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)};
+    auto sizes = c10::IntArrayRef(sizes_vec);
+    TORCH_CHECK(x0.dim() == 2);
     TORCH_CHECK(sizes.size() == 2);

     const int rows = sizes[0];

@@ -124,7 +133,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
         auto rowscale = rowscale_.value();
         TORCH_CHECK(rowscale.is_cuda())
         TORCH_CHECK(rowscale.is_contiguous());
-        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
+        TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
         TORCH_CHECK(rowscale.dtype() == itype);
     }

@@ -132,10 +141,25 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
         auto colscale = colscale_.value();
         TORCH_CHECK(colscale.is_cuda())
         TORCH_CHECK(colscale.is_contiguous());
-        TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
+        TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
         TORCH_CHECK(colscale.dtype() == wtype);
     }
+    if (x0_subset_.has_value()) {
+        auto x0_subset = x0_subset_.value();
+        TORCH_CHECK(x0_subset.is_cuda())
+        TORCH_CHECK(x0_subset.is_contiguous());
+        TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
+        TORCH_CHECK(z_subset_.has_value());
+        auto z_subset = z_subset_.value();
+        TORCH_CHECK(z_subset.is_cuda());
+        TORCH_CHECK(z_subset.is_contiguous());
+        TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(z_subset.dtype() == torch::kInt32);
+    }

     TORCH_CHECK(gamma.sizes() == beta.sizes());
     TORCH_CHECK(hidden_size == cols);
     TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));

@@ -144,12 +168,12 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     auto opts = x0.options();

-    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || (itype != rtype);
+    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
     at::Tensor x;
     if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
     at::Tensor dmask;
-    if (dropout_p > 0.f) { dmask = torch::empty(sizes, opts.dtype(mtype)); };
-    auto z = torch::empty(sizes, opts.dtype(otype));
+    if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); };
+    auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype));

     auto mu = torch::empty({rows}, opts.dtype(ctype));
     auto rsigma = torch::empty({rows}, opts.dtype(ctype));

@@ -163,6 +187,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
     launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
+    launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
+    launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;

     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());

@@ -192,6 +218,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     params.epsilon = epsilon;
     params.dropout_scale = 1.f / (1.f - dropout_p);
     params.inverse_cols = 1.f / float(params.cols);
+    params.rowscale_const = rowscale_const;

     if (dropout_p > 0.f) {
         // number of times random will be generated per thread, to offset philox counter in thc random

@@ -230,8 +257,12 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
                                            const at::Tensor &rsigma,  // BxS, FP32!
                                            const at::Tensor &gamma,   // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,      // BxS
-                                           c10::optional<const at::Tensor> &colscale_,      // BxS
+                                           c10::optional<const at::Tensor> &colscale_,      // hidden_size
+                                           c10::optional<const at::Tensor> &x0_subset_,     // BxS
+                                           c10::optional<const at::Tensor> &z_subset_,      // BxS
                                            const float dropout_p,
+                                           const float rowscale_const,
+                                           const int64_t x0_numrows,
                                            const bool has_residual
 ) {

@@ -259,9 +290,16 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
     auto sizes = x.sizes();
     TORCH_CHECK(sizes.size() == 2);
-    TORCH_CHECK(dz.sizes() == sizes);
     auto rows = sizes[0];
     auto cols = sizes[1];
+    TORCH_CHECK(dz.dim() == 2);
+    TORCH_CHECK(dz.size(1) == cols);
+    // c10::IntArrayRef does not own the storage, so we need to construct a vector.
+    // Otherwise just constructing IntArrayRef({blah}) will cause unintialized memory because
+    // blah is then deallocated.
+    std::vector<int64_t> x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols};
+    auto x0_sizes = c10::IntArrayRef(x0_sizes_vec);

     if (dx_.has_value()) {
         auto dx = dx_.value();

@@ -276,14 +314,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
         TORCH_CHECK(dmask.dtype() == mtype);
         TORCH_CHECK(dmask.is_cuda());
         TORCH_CHECK(dmask.is_contiguous());
-        TORCH_CHECK(dmask.sizes() == sizes);
+        TORCH_CHECK(dmask.sizes() == x0_sizes);
     }

     if (rowscale_.has_value()) {
         auto rowscale = rowscale_.value();
         TORCH_CHECK(rowscale.is_cuda())
         TORCH_CHECK(rowscale.is_contiguous());
-        TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
+        TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
         TORCH_CHECK(rowscale.dtype() == itype);
     }

@@ -291,17 +329,32 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
         auto colscale = colscale_.value();
         TORCH_CHECK(colscale.is_cuda())
         TORCH_CHECK(colscale.is_contiguous());
-        TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
+        TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
         TORCH_CHECK(colscale.dtype() == wtype);

         TORCH_CHECK(x0_.has_value());
         auto x0 = x0_.value();
         TORCH_CHECK(x0.is_cuda())
         TORCH_CHECK(x0.is_contiguous());
-        TORCH_CHECK(x0.sizes() == sizes);
+        TORCH_CHECK(x0.sizes() == x0_sizes);
         TORCH_CHECK(x0.dtype() == itype);
     }
+    if (x0_subset_.has_value()) {
+        auto x0_subset = x0_subset_.value();
+        TORCH_CHECK(x0_subset.is_cuda())
+        TORCH_CHECK(x0_subset.is_contiguous());
+        TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
+        TORCH_CHECK(z_subset_.has_value());
+        auto z_subset = z_subset_.value();
+        TORCH_CHECK(z_subset.is_cuda());
+        TORCH_CHECK(z_subset.is_contiguous());
+        TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(z_subset.dtype() == torch::kInt32);
+    }

     auto hidden_size = gamma.numel();
     TORCH_CHECK(hidden_size == cols);
     TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));

@@ -313,7 +366,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
     auto opts = x.options();

-    auto dx0 = torch::empty_like(x, opts.dtype(itype));
+    auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
     at::Tensor dx1;
     if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
     auto dgamma = torch::empty_like(gamma);

@@ -331,6 +384,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
     launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
     launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
+    launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
+    launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;

     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);

@@ -366,6 +421,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
         params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
     params.dropout_scale = 1.f / (1.f - dropout_p);
     params.inverse_cols = 1.f / float(params.cols);
+    params.rowscale_const = rowscale_const;

     if (launch_params.barrier_size > 0) {
         // TODO Any way to avoid this?
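Two details here are easy to miss. The new comment about c10::IntArrayRef explains the sizes_vec / x0_sizes_vec locals: IntArrayRef is a non-owning view, so building one directly from a temporary braced list would leave it pointing at freed storage. And once subsets are used the allocation shapes diverge: z is allocated with z_numrows rows when z_subset is given, while dmask (forward) and dx0 (backward) follow x0's possibly smaller row count, and x keeps one row per residual row.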
csrc/layer_norm/ln_bwd_kernels.cuh

@@ -7,7 +7,7 @@
 namespace layer_norm {

-template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
+template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
 void ln_bwd_kernel(layer_norm::BwdParams params) {

@@ -37,6 +37,9 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
     extern __shared__ char smem_[];

+    const bool has_residual = params.dx1 != nullptr;
+    const bool prenorm = params.dx != nullptr;
+
     const index_t tidx = threadIdx.x;
     const index_t bidn = blockIdx.x % CTAS_PER_ROW;
     const index_t bidm = blockIdx.x / CTAS_PER_ROW;

@@ -51,6 +54,10 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
     static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);

+    const input_t *rowscale = static_cast<input_t *>(params.rowscale);
+    const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
+    const index_t *z_subset = static_cast<index_t *>(params.z_subset);
+
     Cvec dzy_sum[LDGS];
     Cvec dz_sum[LDGS];
     Cvec dcolscale_sum[LDGS];

@@ -87,40 +94,62 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
     for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
         const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
         const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
-        const compute_t rowscale_val = params.rowscale == nullptr ? 1.0f : compute_t(static_cast<const input_t *>(params.rowscale)[row]);
+        const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
+        const int row_z = !Has_subset ? row + 1 : z_subset[row];
+        const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
+        const bool load_dz = !Has_subset || row_z > 0;
+        const bool save_dx0 = !Has_subset || row_x0 > 0;
         Mvec dmask[LDGS];
         Rvec dx[LDGS];
         compute_t dy[LDGS * NUM_ELTS];
         compute_t y[LDGS * NUM_ELTS];
         compute_t mdy_local = 0.f;
         compute_t mdyy_local = 0.f;
-        index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
-        #pragma unroll
-        for (int it = 0; it < LDGS; it++) {
-            if (Is_even_cols || (it < num_valid_ldgs)) {
-                Rvec x;
-                Ovec dz;
-                dz.load_from(params.dz, idx);
-                if (Prenorm) { dx[it].load_from(params.dx, idx); }
-                x.load_from(params.x, idx);
-                if (Is_dropout) { dmask[it].load_from(params.dmask, idx); }
-                idx += Ktraits::VEC_COLS_PER_LDG;
-                #pragma unroll
-                for (int jt = 0; jt < NUM_ELTS; jt++) {
-                    compute_t x_tmp = x.data.elt[jt];
-                    compute_t y_tmp = rs_r * (x_tmp - mu_r);
-                    compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
-                    compute_t dz_tmp = dz.data.elt[jt];
-                    mdy_local += dy_tmp;
-                    mdyy_local += dy_tmp * y_tmp;
-                    dy[it * NUM_ELTS + jt] = dy_tmp;
-                    y[it * NUM_ELTS + jt] = y_tmp;
-                    dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
-                    dz_sum[it].data.elt[jt] += dz_tmp;
-                }
-            }
-        }
+        // If dz is not loaded, then dy should be 0 and we don't care about the value of y.
+        if (load_dz) {
+            index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+            index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
+            index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
+            #pragma unroll
+            for (int it = 0; it < LDGS; it++) {
+                if (Is_even_cols || (it < num_valid_ldgs)) {
+                    Rvec x;
+                    Ovec dz;
+                    dz.load_from(params.dz, !Has_subset ? idx_x : idx_z);
+                    if (prenorm) { dx[it].load_from(params.dx, idx_x); }
+                    x.load_from(params.x, idx_x);
+                    if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
+                    idx_x += Ktraits::VEC_COLS_PER_LDG;
+                    idx_z += Ktraits::VEC_COLS_PER_LDG;
+                    idx_x0 += Ktraits::VEC_COLS_PER_LDG;
+                    #pragma unroll
+                    for (int jt = 0; jt < NUM_ELTS; jt++) {
+                        compute_t x_tmp = x.data.elt[jt];
+                        compute_t y_tmp = rs_r * (x_tmp - mu_r);
+                        compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
+                        compute_t dz_tmp = dz.data.elt[jt];
+                        mdy_local += dy_tmp;
+                        mdyy_local += dy_tmp * y_tmp;
+                        dy[it * NUM_ELTS + jt] = dy_tmp;
+                        y[it * NUM_ELTS + jt] = y_tmp;
+                        dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
+                        dz_sum[it].data.elt[jt] += dz_tmp;
+                    }
+                }
+            }
+        } else {
+            index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+            index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
+            #pragma unroll
+            for (int it = 0; it < LDGS; it++) {
+                if (Is_even_cols || (it < num_valid_ldgs)) {
+                    if (prenorm) { dx[it].load_from(params.dx, idx_x); }
+                    if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
+                    idx_x += Ktraits::VEC_COLS_PER_LDG;
+                    idx_x0 += Ktraits::VEC_COLS_PER_LDG;
+                }
+            }
+        }

@@ -129,42 +158,51 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
         mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
         mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;

-        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
         #pragma unroll
         for (int it = 0; it < LDGS; it++) {
             if (Is_even_cols || (it < num_valid_ldgs)) {
                 Ivec dx0;
                 Rvec dx1;
                 Ivec x0;
-                if (Has_colscale) { x0.load_from(params.x0, idx); }
+                if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
                 #pragma unroll
                 for (int jt = 0; jt < NUM_ELTS; jt++) {
-                    compute_t dy_tmp = dy[it * NUM_ELTS + jt];
-                    compute_t y_tmp = y[it * NUM_ELTS + jt];
-                    compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
-                    compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
-                    if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
-                    compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
-                    if (Is_dropout) {
-                        dx0_tmp_res *= params.dropout_scale;
-                        if (Has_colscale) {
-                            dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
-                            dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
-                        } else {
-                            dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
-                        }
-                    } else {
-                        if (Has_colscale) {
-                            dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
-                            dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
-                        } else {
-                            dx0.data.elt[jt] = dx0_tmp_res;
-                        }
-                    }
+                    compute_t dx_tmp_res;
+                    if (load_dz) {
+                        compute_t dy_tmp = dy[it * NUM_ELTS + jt];
+                        compute_t y_tmp = y[it * NUM_ELTS + jt];
+                        compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
+                        dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
+                    } else {
+                        dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
+                    }
+                    if (has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
+                    if (save_dx0) {
+                        compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
+                        if (Is_dropout) {
+                            dx0_tmp_res *= params.dropout_scale;
+                            if (Has_colscale) {
+                                dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
+                                dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
+                            } else {
+                                dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
+                            }
+                        } else {
+                            if (Has_colscale) {
+                                dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
+                                dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
+                            } else {
+                                dx0.data.elt[jt] = dx0_tmp_res;
+                            }
+                        }
+                    }
                 }
-                if (Has_residual) { dx1.store_to(params.dx1, idx); }
-                dx0.store_to(params.dx0, idx);
-                idx += Ktraits::VEC_COLS_PER_LDG;
+                if (has_residual) { dx1.store_to(params.dx1, idx_x); }
+                if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
+                idx_x += Ktraits::VEC_COLS_PER_LDG;
+                idx_x0 += Ktraits::VEC_COLS_PER_LDG;
             }
         }

@@ -434,64 +472,61 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
                                                     WARPS_N,
                                                     BYTES_PER_LDG_MAIN
                                                     >;
-    bool prenorm = launch_params.params.dx != nullptr;
     bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
-    bool has_residual = launch_params.params.dx1 != nullptr;
     bool has_colscale = launch_params.params.colscale != nullptr;
+    bool has_subset = launch_params.params.x0_subset != nullptr;
     bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
-    BOOL_SWITCH(prenorm, PrenormConst, [&] {
-        BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
-            BOOL_SWITCH(has_residual, HasResidualConst, [&] {
-                BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
-                    BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                        auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
+    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
+        BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+            BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
+                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                    auto kernel = &ln_bwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
                     if (configure_params) {
                         int ctas_per_sm;
                         CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                             &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
                         launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                         launch_params.barrier_size = 0;
                         launch_params.workspace_bytes = 0;
                         if (Kernel_traits::CTAS_PER_ROW > 1) {
                             launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                             launch_params.workspace_bytes = launch_params.params.ctas_per_col
                                 * Kernel_traits::WARPS_M
                                 * Kernel_traits::CTAS_PER_ROW
                                 * sizeof(typename Kernel_traits::reduce_t)
                                 * 2;
                         }
                         return;
                     }
                     if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
                         CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
                     }
                     auto stream = launch_params.stream;
                     auto ctas_per_col = launch_params.params.ctas_per_col;
                     if (Kernel_traits::CTAS_PER_ROW == 1) {
                         kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
                     } else {
                         dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                         dim3 block(Kernel_traits::THREADS_PER_CTA);
                         void *params_ = (void *)&launch_params.params;
                         cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
                     }
                     using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
                                                                                weight_t,
                                                                                input_t,
                                                                                residual_t,
                                                                                output_t,
                                                                                compute_t,
                                                                                index_t,
                                                                                HasColscaleConst,
                                                                                32 * 32,  // THREADS_PER_CTA
                                                                                BYTES_PER_LDG_FINAL>;
                     auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
                     kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
-                    });
                 });
             });
         });
     });
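The gradient math itself is unchanged by this commit: per row the kernel still forms dy = gamma * dz, reduces the row means mdy and mdyy (note the params.inverse_cols factors above), and computes dx = rs * (dy - (mdyy * y + mdy)); the subset machinery only redirects where dz is loaded from and where dx0 is stored. A PyTorch restatement of that per-row formula, checked against autograd (a sketch assuming no dropout, rowscale, or colscale):

    import torch

    def ln_bwd_row(dz, x, gamma, mu, rs):
        # Mirrors what ln_bwd_kernel accumulates for a single row.
        y = rs * (x - mu)            # normalized activations (y_tmp)
        dy = gamma * dz              # dy_tmp = gamma * dz
        mdy = dy.mean()              # mdy_local * inverse_cols
        mdyy = (dy * y).mean()       # mdyy_local * inverse_cols
        return rs * (dy - (mdyy * y + mdy))   # dx_tmp

    torch.manual_seed(0)
    x = torch.randn(16, dtype=torch.float64, requires_grad=True)
    gamma = torch.randn(16, dtype=torch.float64)
    dz = torch.randn(16, dtype=torch.float64)
    mu = x.mean()
    rs = (x.var(unbiased=False) + 1e-5).rsqrt()
    (gamma * (x - mu) * rs).backward(dz)
    assert torch.allclose(x.grad, ln_bwd_row(dz, x.detach(), gamma, mu.detach(), rs.detach()))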
csrc/layer_norm/ln_fwd_kernels.cuh

@@ -16,7 +16,7 @@
 namespace layer_norm {

-template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
+template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
 void ln_fwd_kernel(FwdParams params) {

@@ -46,7 +46,8 @@ void ln_fwd_kernel(FwdParams params) {
     using Stats = typename Ktraits::Stats;
     using stats_t = typename Stats::stats_t;

-    const bool save_x = Has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || !(std::is_same<input_t, residual_t>::value);
+    const bool has_residual = params.x1 != nullptr;
+    const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);

     extern __shared__ char smem_[];

@@ -67,6 +68,8 @@ void ln_fwd_kernel(FwdParams params) {
     compute_t *rs_ptr = static_cast<compute_t *>(params.rs);

     const input_t *rowscale = static_cast<input_t *>(params.rowscale);
+    const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
+    const index_t *z_subset = static_cast<index_t *>(params.z_subset);

     // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
     curandStatePhilox4_32_10_t state;

@@ -93,8 +96,12 @@ void ln_fwd_kernel(FwdParams params) {
     }

     for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
-        const compute_t rowscale_val = params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row]);
-        index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
+        const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
+        const int row_z = !Has_subset ? row + 1 : z_subset[row];
+        const bool load_x0 = !Has_subset || row_x0 > 0;
+        index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
+        index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
         compute_t xf[LDGS * NUM_ELTS];
         #pragma unroll
         for (int it = 0; it < LDGS; it++) {

@@ -103,24 +110,30 @@ void ln_fwd_kernel(FwdParams params) {
                 Rvec x1;
                 Rvec x;
                 Mvec dmask;
-                x0.load_from(params.x0, idx);
-                if (Has_residual) { x1.load_from(params.x1, idx); }
+                if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
+                if (has_residual) { x1.load_from(params.x1, idx_x); }
                 #pragma unroll
                 for (int jt = 0; jt < NUM_ELTS; jt++) {
-                    // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
-                    // the more efficient curand_uniform4.
-                    mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
-                    compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
-                    x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
-                    if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
-                    if (Is_dropout) { dmask.data.elt[jt] = keep; }
-                    compute_t x_ij = Has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
+                    compute_t x_ij;
+                    if (load_x0) {
+                        // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
+                        // the more efficient curand_uniform4.
+                        mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
+                        compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
+                        x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
+                        if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
+                        if (Is_dropout) { dmask.data.elt[jt] = keep; }
+                        x_ij = has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
+                    } else {
+                        x_ij = has_residual ? compute_t(x1.data.elt[jt]) : 0.f;
+                    }
                     if (save_x) { x.data.elt[jt] = x_ij; }
                     xf[it * NUM_ELTS + jt] = x_ij;
                 }
-                if (save_x) { x.store_to(params.x, idx); }
-                if (Is_dropout) { dmask.store_to(params.dmask, idx); }
-                idx += VEC_COLS_PER_LDG;
+                if (save_x) { x.store_to(params.x, idx_x); }
+                if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); }
+                idx_x += VEC_COLS_PER_LDG;
+                idx_x0 += VEC_COLS_PER_LDG;
             }
         }

@@ -152,20 +165,23 @@ void ln_fwd_kernel(FwdParams params) {
             rs_ptr[row] = rs;
         }

-        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
-        #pragma unroll
-        for (int it = 0; it < LDGS; it++) {
-            if (Is_even_cols || (it < num_valid_ldgs)) {
-                Ovec z;
-                #pragma unroll
-                for (int jt = 0; jt < NUM_ELTS; jt++) {
-                    compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu));
-                    compute_t g_ij = gamma[it].data.elt[jt];
-                    compute_t b_ij = beta[it].data.elt[jt];
-                    z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
-                }
-                z.store_to(params.z, idx);
-                idx += VEC_COLS_PER_LDG;
-            }
-        }
+        const bool save_z = !Has_subset || row_z > 0;
+        if (save_z) {
+            index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c;
+            #pragma unroll
+            for (int it = 0; it < LDGS; it++) {
+                if (Is_even_cols || (it < num_valid_ldgs)) {
+                    Ovec z;
+                    #pragma unroll
+                    for (int jt = 0; jt < NUM_ELTS; jt++) {
+                        compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu));
+                        compute_t g_ij = gamma[it].data.elt[jt];
+                        compute_t b_ij = beta[it].data.elt[jt];
+                        z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
+                    }
+                    z.store_to(params.z, idx_z);
+                    idx_z += VEC_COLS_PER_LDG;
+                }
+            }
+        }

@@ -203,14 +219,14 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
                                                     WARPS_N,
                                                     BYTES_PER_LDG
                                                     >;
-    bool has_residual = launch_params.params.x1 != nullptr;
     bool has_colscale = launch_params.params.colscale != nullptr;
+    bool has_subset = launch_params.params.x0_subset != nullptr;
     bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
     BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
-        BOOL_SWITCH(has_residual, HasResidualConst, [&] {
-            BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
-                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
-                    auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
+        BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+            BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
+                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                    auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
                     if (configure_params) {
                         int ctas_per_sm;
                         CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
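Putting the forward pieces together: a row whose x0_subset entry is 0 skips the x0 load and contributes only its residual x1, every residual row is normalized, and only rows with a positive z_subset entry are written to z, at the packed position the entry names. A rough PyTorch restatement of that flow (a sketch under the 1-based/0-skip assumption, ignoring dropout, rowscale, and colscale; subset_ln_fwd_reference is an illustrative name, not an API from the repo):

    import torch
    import torch.nn.functional as F

    def subset_ln_fwd_reference(x0, x1, weight, bias, x0_subset, z_subset, z_numrows, eps=1e-5):
        # x0 holds only the selected rows, packed; x1 has one row per residual row.
        rows, cols = x1.shape
        x = x1.clone()
        for row in range(rows):
            i = int(x0_subset[row])
            if i > 0:                    # load_x0: this row has an x0 contribution
                x[row] += x0[i - 1]
        z = x.new_zeros(z_numrows, cols)
        for row in range(rows):
            j = int(z_subset[row])
            if j > 0:                    # save_z: only selected rows are written out
                z[j - 1] = F.layer_norm(x[row], (cols,), weight, bias, eps)
        return z, x                      # x is the saved pre-norm residual stream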
flash_attn/ops/layer_norm.py

@@ -16,7 +16,8 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro
     x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32
+        x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon, 1.0, 0,
+        None, residual_in_fp32
     )
     # dmask is None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype

@@ -36,12 +37,59 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
     dxmat = dx.view(xmat.shape) if dx is not None else None
     x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
-    colscale = colscale.view(-1) if colscale is not None else None
     if colscale is not None:
         assert x0 is not None, 'x0 is required to compute the gradient of colscale'
     dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
-        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p,
-        has_residual
+        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
+        dropout_p, 1.0, 0, has_residual
     )
+    # dx1mat is None if not has_residual
+    if colscale is None:
+        return dx0mat, dx1mat, dgamma, dbeta
+    else:
+        dcolscale = rest[0]
+        return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+
+
+def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
+                                           dropout_p, epsilon, rowscale_const, out_numrows,
+                                           residual_in_fp32):
+    """ Assume that arguments are contiguous
+    """
+    hidden_size = gamma.numel()
+    x0mat = x0.view((-1, hidden_size))
+    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
+    out_subset = out_subset.view(-1) if out_subset is not None else None
+    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
+        x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
+        rowscale_const, out_numrows, None, residual_in_fp32
+    )
+    # dmask is None if dropout_p == 0.0
+    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
+
+
+def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
+                                            x0_subset, out_subset, dropout_p, rowscale_const,
+                                            x0_numrows, has_residual):
+    """ Assume that arguments are contiguous
+    dx == None means that it was a post-norm architecture
+    (x = drop(x0) + x1 was not returned in the fwd).
+    x0 must not be None if we have colscale.
+    """
+    hidden_size = gamma.numel()
+    xmat = x.view((-1, hidden_size))
+    dzmat = dz.view(-1, hidden_size)
+    dxmat = dx.view(xmat.shape) if dx is not None else None
+    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
+    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
+    out_subset = out_subset.view(-1) if out_subset is not None else None
+    if colscale is not None:
+        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
+    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
+        dropout_p, rowscale_const, x0_numrows, has_residual
+    )
     # dx1mat is None if not has_residual
     if colscale is None:

@@ -98,6 +146,60 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None


+class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+                rowscale_const, out_numrows, residual_in_fp32, prenorm=False, return_dmask=False):
+        x0 = x0.contiguous()
+        x1 = x1.contiguous() if x1 is not None else None
+        gamma = gamma.contiguous()
+        beta = beta.contiguous()
+        colscale = colscale.contiguous() if colscale is not None else None
+        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
+            x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+            rowscale_const, out_numrows, residual_in_fp32
+        )
+        # Only need to save x0 if we need to compute gradient wrt colscale
+        x0_saved = x0 if colscale is not None else None
+        x_shape = (-1, *x0.shape[1:])
+        ctx.save_for_backward(xmat.view(x_shape), x0, dmask, gamma, mu, rsigma, colscale,
+                              x0_subset, out_subset)
+        ctx.prenorm = prenorm
+        ctx.dropout_p = dropout_p
+        ctx.rowscale_const = rowscale_const
+        ctx.x0_numrows = x0.shape[:-1].numel()
+        ctx.has_residual = x1 is not None
+        z_shape = (-1, *x0.shape[1:])
+        if not return_dmask:
+            return (zmat.view(z_shape) if not prenorm
+                    else (zmat.view(z_shape), xmat.view(x0.shape)))
+        else:
+            z = zmat.view(z_shape)
+            dmask = (dmask.view(x0.shape) if dropout_p > 0.
+                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            ctx.mark_non_differentiable(dmask)
+            return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))
+
+    @staticmethod
+    def backward(ctx, dz, *args):
+        # assert dz.is_contiguous()
+        dz = dz.contiguous()  # this happens!
+        dx = args[0].contiguous() if ctx.prenorm else None
+        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
+        # x0 is None if colscale is None
+        dropout_p = ctx.dropout_p
+        has_residual = ctx.has_residual
+        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
+            dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
+            ctx.rowscale_const, ctx.x0_numrows, has_residual
+        )
+        dx0 = dx0mat.view(-1, *x.shape[1:])
+        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dcolscale = rest[0] if colscale is not None else None
+        return (dx0, dx1, dgamma, dbeta, dcolscale, None, None, None, None, None, None, None,
+                None, None)
+
+
 def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
                            prenorm=False, residual_in_fp32=False,
                            return_dropout_mask=False):

@@ -110,6 +212,19 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No
     )


+def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
+                                  x0_subset=None, out_subset=None, rowscale_const=1.0,
+                                  out_numrows=0, prenorm=False, residual_in_fp32=False,
+                                  return_dropout_mask=False):
+    """residual_in_fp32 only has an effect if x1 is None.
+    Otherwise residual dtype is x1.dtype.
+    """
+    return DropoutAddLayerNormSubsetFn.apply(
+        x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
+        rowscale_const, out_numrows, residual_in_fp32, prenorm, return_dropout_mask
+    )
+
+
 class DropoutAddLayerNorm(torch.nn.Module):
     def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):
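For orientation, a call to the new Python entry point might look like the following; the argument names and order come from the dropout_add_layer_norm_subset signature in this diff, while the shapes and subset values are made-up illustrations (running it requires the compiled dropout_layer_norm extension and a CUDA device):

    import torch
    from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset

    hidden = 256
    device, dtype = 'cuda', torch.float16
    weight = torch.ones(hidden, device=device, dtype=dtype)
    bias = torch.zeros(hidden, device=device, dtype=dtype)

    # 8 residual rows; keep 5 of them in x0 and write 5 output rows.
    # Entries are 1-based packed positions; 0 means the row is skipped.
    x0_subset = torch.tensor([1, 2, 0, 3, 0, 4, 5, 0], dtype=torch.int32, device=device)
    out_subset = torch.tensor([1, 0, 2, 0, 3, 0, 4, 5], dtype=torch.int32, device=device)

    x0 = torch.randn(5, hidden, device=device, dtype=dtype)  # packed subset rows only
    x1 = torch.randn(8, hidden, device=device, dtype=dtype)  # full residual stream

    out = dropout_add_layer_norm_subset(
        x0, x1, weight, bias, dropout_p=0.1, epsilon=1e-5,
        x0_subset=x0_subset, out_subset=out_subset, out_numrows=5)
    # out has 5 rows: one per positive entry of out_subset.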
tests/ops/test_dropout_layer_norm.py
View file @
5db33051
...
@@ -4,9 +4,10 @@ import torch
...
@@ -4,9 +4,10 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
pytest
import
pytest
from
einops
import
rearrange
from
einops
import
rearrange
,
repeat
from
flash_attn.ops.layer_norm
import
DropoutAddLayerNorm
,
dropout_add_layer_norm
from
flash_attn.ops.layer_norm
import
DropoutAddLayerNorm
,
dropout_add_layer_norm
from
flash_attn.ops.layer_norm
import
dropout_add_layer_norm_subset
is_sm8x
=
torch
.
cuda
.
get_device_capability
(
'cuda'
)[
0
]
>=
8
is_sm8x
=
torch
.
cuda
.
get_device_capability
(
'cuda'
)[
0
]
>=
8
...
@@ -130,6 +131,8 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
...
@@ -130,6 +131,8 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
x1
=
x1_pt
.
detach
().
clone
().
requires_grad_
()
x1
=
x1_pt
.
detach
().
clone
().
requires_grad_
()
x1_ref
=
x1_pt
.
detach
().
clone
().
float
().
requires_grad_
()
x1_ref
=
x1_pt
.
detach
().
clone
().
float
().
requires_grad_
()
model_pt
=
torch
.
nn
.
LayerNorm
(
hidden_size
,
device
=
device
,
dtype
=
weight_dtype
)
model_pt
=
torch
.
nn
.
LayerNorm
(
hidden_size
,
device
=
device
,
dtype
=
weight_dtype
)
torch
.
nn
.
init
.
normal_
(
model_pt
.
weight
)
torch
.
nn
.
init
.
normal_
(
model_pt
.
bias
)
model
=
DropoutAddLayerNorm
(
hidden_size
,
p
=
dropout_p
,
device
=
device
,
dtype
=
weight_dtype
)
model
=
DropoutAddLayerNorm
(
hidden_size
,
p
=
dropout_p
,
device
=
device
,
dtype
=
weight_dtype
)
model_ref
=
torch
.
nn
.
LayerNorm
(
hidden_size
,
device
=
device
,
dtype
=
torch
.
float32
)
model_ref
=
torch
.
nn
.
LayerNorm
(
hidden_size
,
device
=
device
,
dtype
=
torch
.
float32
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -148,22 +151,23 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
...
@@ -148,22 +151,23 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
assert
(
out
-
out_ref
).
abs
().
max
()
<=
4
*
(
out_pt
-
out_ref
).
abs
().
max
()
+
1e-4
assert
(
out
-
out_ref
).
abs
().
max
()
<=
4
*
(
out_pt
-
out_ref
).
abs
().
max
()
+
1e-4
@
pytest
.
mark
.
parametrize
(
'has_colscale'
,
[
True
,
False
])
# @pytest.mark.parametrize('has_colscale', [True, False])
@
pytest
.
mark
.
parametrize
(
'has_rowscale'
,
[
True
,
False
])
# @pytest.mark.parametrize('has_rowscale', [True, False])
@
pytest
.
mark
.
parametrize
(
'has_residual'
,
[
True
,
False
])
# @pytest.mark.parametrize('has_residual', [True, False])
@
pytest
.
mark
.
parametrize
(
'dropout_p'
,
[
0.37
,
0.0
])
# @pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@
pytest
.
mark
.
parametrize
(
'weight_dtype'
,
[
torch
.
float32
,
torch
.
float16
])
# @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@
pytest
.
mark
.
parametrize
(
'input_dtype,residual_dtype'
,
# @pytest.mark.parametrize('input_dtype,residual_dtype',
[(
torch
.
float16
,
torch
.
float16
),
(
torch
.
float16
,
torch
.
float32
),
# [(torch.float16, torch.float16), (torch.float16, torch.float32),
(
torch
.
float32
,
torch
.
float32
)]
# (torch.float32, torch.float32)]
+
([(
torch
.
bfloat16
,
torch
.
bfloat16
),
(
torch
.
bfloat16
,
torch
.
float32
)]
if
is_sm8x
else
[]))
# + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
@
pytest
.
mark
.
parametrize
(
'has_colscale'
,
[
True
])
# @pytest.mark.parametrize('has_rowscale', [False])
@
pytest
.
mark
.
parametrize
(
'has_rowscale'
,
[
False
])
# @pytest.mark.parametrize('has_residual', [False])
@
pytest
.
mark
.
parametrize
(
'has_residual'
,
[
True
])
# @pytest.mark.parametrize('dropout_p', [0.0])
@
pytest
.
mark
.
parametrize
(
'dropout_p'
,
[
0.0
])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@
pytest
.
mark
.
parametrize
(
'weight_dtype'
,
[
torch
.
float32
])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@
pytest
.
mark
.
parametrize
(
'input_dtype,residual_dtype'
,
[(
torch
.
float32
,
torch
.
float32
)])
@
pytest
.
mark
.
parametrize
(
'hidden_size'
,
[
192
,
256
,
384
,
768
,
1024
,
1280
,
1536
,
1600
,
2048
,
2560
,
3000
,
3072
,
4096
,
5120
,
6144
])
# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
@
pytest
.
mark
.
parametrize
(
'hidden_size'
,
[
256
])
def
test_dropout_layer_norm_prenorm_training
(
hidden_size
,
input_dtype
,
residual_dtype
,
weight_dtype
,
def
test_dropout_layer_norm_prenorm_training
(
hidden_size
,
input_dtype
,
residual_dtype
,
weight_dtype
,
dropout_p
,
has_residual
,
has_rowscale
,
has_colscale
):
dropout_p
,
has_residual
,
has_rowscale
,
has_colscale
):
if
weight_dtype
==
torch
.
float16
and
input_dtype
==
torch
.
bfloat16
:
if
weight_dtype
==
torch
.
float16
and
input_dtype
==
torch
.
bfloat16
:
...
@@ -205,6 +209,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
...
@@ -271,6 +277,8 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
...
@@ -289,3 +297,245 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                            dropout_p, has_residual, has_colscale):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)
    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
        return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)
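    # Worked example of the subset convention built above (hypothetical values,
    # not part of the test): with mask_batch = [True, False, True] and seqlen = 2,
    # mask_batch_seqlen = [1, 1, 0, 0, 1, 1], its cumsum is [1, 2, 2, 2, 3, 4],
    # and filling the dropped rows gives subset = [1, 2, 0, 0, 3, 4]. Kept rows
    # thus carry 1-based destination indices, while 0 marks a dropped row.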
    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen, drop_path_rate, device)
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen, drop_path_rate, device)
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype,
                               requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=False, p=dropout_p, device=device,
                                dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm_subset(
        x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
        x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
        out_numrows=out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True
    )
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    x0_scaled_pt = x0_scaled_pt.masked_fill(
        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0) * drop_path_scale
    x0_scaled_ref = x0_scaled_ref.masked_fill(
        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0) * drop_path_scale
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)
                       + x1_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float())
                       / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
    if has_residual:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
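Stripped of the reference checks, the calling convention this test exercises is easier to see in isolation. A minimal sketch, assuming the same imports as this test file (torch, einops, and dropout_add_layer_norm_subset from flash_attn.ops.layer_norm) plus a CUDA build of the extension; the names keep, rows, subset and all sizes and constants are illustrative, not from the test:

import torch
from einops import rearrange, repeat
from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset

batch, seqlen, hidden = 2, 4, 256
keep = torch.tensor([True, False])                  # batch-level drop-path mask (on CPU)
rows = repeat(keep, 'b -> (b s)', s=seqlen)
# 1-based destination index per row; 0 marks a dropped row
subset = rearrange(torch.cumsum(rows, 0, dtype=torch.int32).masked_fill_(~rows, 0),
                   '(b s) -> b s', b=batch).cuda()
numrows = int(keep.sum()) * seqlen                  # rows actually materialized
x0_full = torch.randn(batch, seqlen, hidden, device='cuda', dtype=torch.float16)
x0 = x0_full[keep.cuda()]                           # only surviving sequences are passed in
weight = torch.ones(hidden, device='cuda', dtype=torch.float16)
bias = torch.zeros(hidden, device='cuda', dtype=torch.float16)
out, dmask = dropout_add_layer_norm_subset(
    x0, None, weight, bias, 0.1, 1e-5,
    x0_subset=subset, out_subset=subset, rowscale_const=2.0,
    out_numrows=numrows, prenorm=False, residual_in_fp32=True,
    return_dropout_mask=True)

Here the same subset selects both input and output rows; the tests above draw two independent masks precisely to cover the case where the input subset and the output subset differ.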
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                                    dropout_p, has_residual, has_colscale):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
        return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen, drop_path_rate, device)
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen, drop_path_rate, device)

    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype,
                               requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm_subset(
        x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
        x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
        out_numrows=out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True
    )
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    x0_scaled_pt = x0_scaled_pt.masked_fill(
        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0) * drop_path_scale
    x0_scaled_ref = x0_scaled_ref.masked_fill(
        repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0) * drop_path_scale
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)
                       + x1_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float())
                       / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(g)
    (out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
    (out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype))
     + residual_ref.mean(0, keepdim=True)).backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
    if has_residual:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
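Taken together, the two subset tests pin down the semantics of dropout_add_layer_norm_subset: x0 is supplied only for rows whose x0_subset entry is nonzero, each surviving row is multiplied by rowscale_const (here the drop-path factor 1 / (1 - drop_path_rate)) before dropout and the residual add, and output is written only for the out_numrows rows selected by out_subset. The masked_fill plus boolean-indexing reference above computes the same quantity, up to the stated tolerances.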