OpenDAS / FastFold · Commits

Commit 771d4b83 (unverified), authored Jun 03, 2022 by shenggan, committed by GitHub Jun 03, 2022

use template in layernorm kernel & add unittest for fastnn layernorm (#25)

Parent: ad1bbc52

Showing 3 changed files with 62 additions and 86 deletions (+62 -86)
fastfold/model/fastnn/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu   +21 -86
tests/__init__.py                                                          +0  -0
tests/test_fastnn/test_layernorm.py                                       +41  -0
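The gist of the change, before reading the hunks: the two dtype-specific kernels collapse into a single templated kernel, and the host-side dispatch gains an explicit branch per dtype. A minimal sketch of the pattern, with signatures abridged from the diff below (the /* ... */ elisions are mine):

// Before: one kernel definition per dtype, with near-identical bodies.
__global__ void fastfold_layernorm_fp32(float* input, /* ... */ double epsilon);
__global__ void fastfold_layernorm_bfp16(at::BFloat16* input, /* ... */ double epsilon);

// After: one definition; cuda_layer_norm picks the instantiation by dtype:
// <float>, <at::Half>, or <at::BFloat16>.
template <typename T>
__global__ void fastfold_layernorm(T* input, T* output, T* gamma, T* beta, float* mean,
                                   float* invvar, int rows, int cols, double epsilon);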
fastfold/model/fastnn/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu @ 771d4b83
@@ -58,78 +58,9 @@ __inline__ __device__ void WelfordWarpAllReduce(float thread_mean, float thread_
     *count = __shfl_sync(0xffffffff, *count, 0, 32);
 }

-__global__ void fastfold_layernorm_fp32(float* input, float* output, float* gamma, float* beta,
-                                        float* mean, float* invvar, int rows, int cols,
-                                        double epsilon) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    int row_offset = blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-
-    int last_y = (cols / cols_per_thread);
-
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    } else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
-
-    int lane_id = threadidx_y;
-
-    if (row_offset < rows) {
-        float buf[32];
-
-        float thread_mean = 0.f;
-        float thread_m2 = 0.f;
-        float thread_count = 0.f;
-
-        float warp_mean;
-        float warp_m2;
-        float warp_count;
-
-        float* row_input = input + row_offset * cols;
-        float* row_output = output + row_offset * cols;
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            buf[i] = row_input[lane_id * cols_per_thread + i];
-        }
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
-        }
-
-        WelfordWarpAllReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2,
-                             &warp_count);
-
-        float row_mean = warp_mean;
-        float row_variance = max(warp_m2 / warp_count, 0.f);
-        float row_inv_var = rsqrt(row_variance + epsilon);
-
-        if (lane_id == 0) {
-            mean[row_offset] = row_mean;
-            invvar[row_offset] = row_inv_var;
-        }
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            buf[i] = (buf[i] - row_mean) * row_inv_var;
-        }
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            row_output[lane_id * cols_per_thread + i] =
-                buf[i] * gamma[lane_id * cols_per_thread + i] +
-                beta[lane_id * cols_per_thread + i];
-        }
-    }
-}
-
-__global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* output,
-                                         at::BFloat16* gamma, at::BFloat16* beta, float* mean,
-                                         float* invvar, int rows, int cols, double epsilon) {
+template <typename T>
+__global__ void fastfold_layernorm(T* input, T* output, T* gamma, T* beta, float* mean,
+                                   float* invvar, int rows, int cols, double epsilon) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     int row_offset = blockIdx.x * 4 + threadidx_x;
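The kernel body leans on two helpers whose definitions sit above the visible hunks: WelfordOnline folds one value into a running (mean, m2, count) triple, and WelfordWarpAllReduce merges the 32 per-lane triples across a warp (only its trailing __shfl_sync broadcast is visible above). For reference, a sketch of the online update following the standard Welford recurrence; this is an assumption about the helper's shape, not code copied from this file:

// Standard single-value Welford update: mean and m2 (sum of squared
// deviations) stay numerically stable under streaming insertion.
__inline__ __device__ void WelfordOnline(float val, float* mean, float* m2, float* count) {
    *count += 1.f;
    float delta1 = val - *mean;
    *mean += delta1 / *count;
    float delta2 = val - *mean;  // recomputed against the updated mean
    *m2 += delta1 * delta2;
}

Variance then falls out as m2 / count, which is exactly what the kernel later computes as row_variance = max(warp_m2 / warp_count, 0.f).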
@@ -140,15 +71,13 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
     } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }

     int lane_id = threadidx_y;

     if (row_offset < rows) {
         float buf[32];

         float thread_mean = 0.f;
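To make the column partitioning above concrete, here is its integer arithmetic traced for a width that is not a multiple of 32, the same case the new test exercises with shape [64, 129]:

// Worked example, cols = 129:
//   cols_per_thread = (129 + 31) / 32 = 5
//   last_y          = 129 / 5        = 25
//   lanes 0..24  handle 5 columns each (125 columns)
//   lane 25      handles 129 - 5 * 25 = 4 columns
//   lanes 26..31 handle 0 columns
// Each lane stages its slice in buf[32], so the kernel assumes
// cols_per_thread <= 32, i.e. cols <= 1024 (the largest test shape).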
@@ -159,20 +88,21 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
         float warp_m2;
         float warp_count;

-        at::BFloat16* row_input = input + row_offset * cols;
-        at::BFloat16* row_output = output + row_offset * cols;
+        T* row_input = input + row_offset * cols;
+        T* row_output = output + row_offset * cols;

 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
         }

 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
         }

         WelfordWarpAllReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2,
                              &warp_count);

         float row_mean = warp_mean;
         float row_variance = max(warp_m2 / warp_count, 0.f);
@@ -183,15 +113,15 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
             invvar[row_offset] = row_inv_var;
         }

 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = (buf[i] - row_mean) * row_inv_var;
         }

 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
-                static_cast<at::BFloat16>(buf[i]) * gamma[lane_id * cols_per_thread + i] +
+                static_cast<T>(buf[i]) * gamma[lane_id * cols_per_thread + i] +
                 beta[lane_id * cols_per_thread + i];
         }
     }
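For reference, the two unrolled loops above implement the standard per-row LayerNorm transform, with the statistics held in fp32 regardless of T:

y_i = \gamma_i \cdot \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i

where \mu and \sigma^2 are the row mean and variance from the Welford reduction, and the cached invvar is the 1 / \sqrt{\sigma^2 + \epsilon} factor.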
@@ -204,12 +134,17 @@ void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, a
     dim3 block(128);

     if (output->dtype() == torch::kFloat32) {
-        fastfold_layernorm_fp32<<<grid, block>>>(
+        fastfold_layernorm<float><<<grid, block>>>(
             (float*)input->data_ptr(), (float*)output->data_ptr(), (float*)gamma->data_ptr(),
             (float*)beta->data_ptr(), (float*)mean->data_ptr(), (float*)invvar->data_ptr(), rows,
             cols, epsilon);
-    } else {
-        fastfold_layernorm_bfp16<<<grid, block>>>(
+    } else if (output->dtype() == torch::kFloat16) {
+        fastfold_layernorm<at::Half><<<grid, block>>>(
+            (at::Half*)input->data_ptr(), (at::Half*)output->data_ptr(),
+            (at::Half*)gamma->data_ptr(), (at::Half*)beta->data_ptr(),
+            (float*)mean->data_ptr(), (float*)invvar->data_ptr(), rows, cols, epsilon);
+    } else if (output->dtype() == torch::kBFloat16) {
+        fastfold_layernorm<at::BFloat16><<<grid, block>>>(
             (at::BFloat16*)input->data_ptr(), (at::BFloat16*)output->data_ptr(),
             (at::BFloat16*)gamma->data_ptr(), (at::BFloat16*)beta->data_ptr(),
             (float*)mean->data_ptr(), (float*)invvar->data_ptr(), rows, cols, epsilon);
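One note on the launch geometry: block is fixed at 128 threads, i.e. four warps, and inside the kernel row_offset = blockIdx.x * 4 + threadidx_x assigns one row per warp. The grid size is set outside the visible hunk; a sketch of the presumed host-side setup (the ceil-division grid is an assumption, not shown in this diff):

dim3 block(128);            // 4 warps per block, one row per warp
dim3 grid((rows + 3) / 4);  // assumed: ceil(rows / 4) blocks to cover every row
fastfold_layernorm<float><<<grid, block>>>(/* pointers, rows, cols, epsilon */);

The row_offset < rows guard in the kernel keeps the padded final block safe.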
tests/__init__.py (new file, mode 100644, empty) @ 771d4b83
tests/test_fastnn/test_layernorm.py (new file, mode 100644) @ 771d4b83
import torch

from fastfold.model.fastnn.kernel import LayerNorm as FastLayerNorm


def test_layernorm():
    # [batch, dim]
    test_shape = [[64, 64], [64, 128], [64, 129], [64, 1024]]
    test_dtype = [torch.float32, torch.float16, torch.bfloat16]
    test_device = torch.device("cuda")
    tolerance_eps = {torch.float32: 10e-5, torch.float16: 10e-2, torch.bfloat16: 10e-2}

    for shape in test_shape:
        for dtype in test_dtype:
            sample_input = torch.rand(shape).to(device=test_device,
                                                dtype=dtype).requires_grad_(False)
            dim_ = sample_input.size()[-1]

            torch_module = torch.nn.LayerNorm(normalized_shape=dim_).to(device=test_device,
                                                                        dtype=dtype)
            fastnn_module = FastLayerNorm(normalized_shape=dim_).to(device=test_device,
                                                                    dtype=dtype)

            # Forward
            torch_out = torch_module(sample_input)
            fastnn_out = fastnn_module(sample_input)

            forward_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
            assert forward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"

            # Backward
            out_grad = torch.rand_like(torch_out).requires_grad_(False)
            torch_out.backward(out_grad)
            fastnn_out.backward(out_grad)

            backward_weight_error = torch.max(
                torch.abs(torch_module.weight.grad - fastnn_module.weight.grad)).cpu().item()
            assert backward_weight_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"

            backward_bias_error = torch.max(
                torch.abs(torch_module.bias.grad - fastnn_module.bias.grad)).cpu().item()
            assert backward_bias_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"