OpenDAS / FastFold, commit a65d5009 (unverified)
Authored Jun 03, 2022 by shenggan; committed by GitHub on Jun 03, 2022.
use template for fused softmax & add unittest for fused softmax (#26)
parent 771d4b83

Showing 2 changed files with 171 additions and 394 deletions (+171 -394):

- fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu (+137 -394)
- tests/test_fastnn/test_softmax.py (+34 -0)
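In short: the fused softmax CUDA kernels (plain softmax, scale+mask, scale+mask+bias, and their backward passes) previously existed as separate `_fp32` and `_bfp16` variants. This commit collapses each pair into a single `template <typename T>` kernel, extends the host-side dtype dispatch to float32, float16 (`at::Half`), and bfloat16, and adds a unit test that checks the fused softmax against `torch.nn.functional.softmax`.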
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu (view file @ a65d5009)
File prologue:

```cpp
#include <c10/cuda/CUDAGuard.h>
#include <math_constants.h>
#include <torch/extension.h>

#include <iostream>
```

The first hunk replaces the fp32-specific forward kernel signature with a templated one:

```diff
@@ -30,7 +30,8 @@ __inline__ __device__ float WarpAllReduceSum(float val) {
 ////////////////
 
-__global__ void fastfold_softmax_fp32(float *input, float *output, long long rows, long long cols) {
+template <typename T>
+__global__ void fastfold_softmax(T *input, T *output, long long rows, long long cols) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
```
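Each thread block handles four rows and each 32-lane warp reduces one row, so a lane holds at most 32 columns in its local `buf`. The `WarpAllReduceMax`/`WarpAllReduceSum` helpers used below are defined earlier in the file and are not part of this diff; as orientation only, a typical shuffle-based implementation of such helpers (an assumption about their shape, not necessarily the file's exact code) looks like:

```cpp
// Sketch only: warp-wide all-reduce via XOR butterfly shuffles.
// The helpers in softmax_cuda_kernel.cu may differ in detail.
__inline__ __device__ float WarpAllReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xffffffff, val, offset);  // every lane ends up with the full sum
    }
    return val;
}

__inline__ __device__ float WarpAllReduceMax(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = max(val, __shfl_xor_sync(0xffffffff, val, offset));  // every lane ends up with the max
    }
    return val;
}
```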
A follow-up hunk (`@@ -41,8 +42,7 @@`) only touches the formatting of the tail-handling `} else if (threadidx_y > last_y) {` branch. The rest of the change in this region (`@@ -51,72 +51,17 @@` and `@@ -124,74 +69,47 @@`) deletes the remainder of `fastfold_softmax_fp32` together with the whole `fastfold_softmax_bfp16` kernel, whose body was the same algorithm with `at::BFloat16` pointers and explicit `static_cast<float>` loads. The resulting templated forward kernel:

```cpp
template <typename T>
__global__ void fastfold_softmax(T *input, T *output, long long rows, long long cols) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;

    int last_y = (cols / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    } else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_input = input + row_offset * cols;
        T *row_output = output + row_offset * cols;

        float thread_max = -1 * CUDART_INF_F;

#pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            buf[i] = row_input[lane_id * cols_per_thread + i];
        }

#pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_max = max(thread_max, buf[i]);
        }

        float warp_max = WarpAllReduceMax(thread_max);

        float thread_sum = 0.f;
#pragma unroll
        for (int i = 0; i < cols_this_thread; ++i) {
            buf[i] = __expf(buf[i] - warp_max);
            thread_sum += buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

#pragma unroll
        for (int i = 0; i < cols_this_thread; ++i) {
            row_output[lane_id * cols_per_thread + i] = static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}
```
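For reference, each warp computes the numerically stable softmax of its row; subtracting the per-row maximum before `__expf` keeps the exponentials from overflowing:

```math
y_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j
```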
The standalone backward kernel `fastfold_softmax_grad_fp32` is deleted in the same way (its body reappears below as the templated `fastfold_softmax_grad`). Interleaved with that deletion, the diff shows the host-side `softmax` entry point, which now dispatches on the tensor dtype and gains a float16 path:

```cpp
at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
    CHECK_INPUT(input);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

    at::Tensor output = at::empty_like(input);

    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (input.dtype() == torch::kFloat32) {
        fastfold_softmax<float><<<grid, block>>>((float *)input.data_ptr(),
                                                 (float *)output.data_ptr(), rows, cols);
    } else if (input.dtype() == torch::kFloat16) {
        fastfold_softmax<at::Half><<<grid, block>>>((at::Half *)input.data_ptr(),
                                                    (at::Half *)output.data_ptr(), rows, cols);
    } else if (input.dtype() == torch::kBFloat16) {
        fastfold_softmax<at::BFloat16><<<grid, block>>>((at::BFloat16 *)input.data_ptr(),
                                                        (at::BFloat16 *)output.data_ptr(), rows, cols);
    }

    return output;
}
```
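As an aside (not something this commit does), the same three-way dtype dispatch is often written with PyTorch's `AT_DISPATCH_FLOATING_TYPES_AND2` macro, which generates the branches and a `scalar_t` alias; a rough sketch of that alternative, assuming the templated kernel above:

```cpp
// Alternative sketch, not the code in this commit: let ATen generate the dtype branches.
AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "fastfold_softmax", [&] {
        fastfold_softmax<scalar_t><<<grid, block>>>((scalar_t *)input.data_ptr(),
                                                    (scalar_t *)output.data_ptr(), rows, cols);
    });
```

Note that the macro would also instantiate a `double` branch, which the hand-written dispatch avoids.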
`fastfold_softmax_grad_bfp16` is likewise folded into a single templated backward kernel; the `at::BFloat16` pointers and `static_cast<at::BFloat16>` store become `T` and `static_cast<T>`:

```cpp
template <typename T>
__global__ void fastfold_softmax_grad(T *d_output, T *output, T *d_input, long long rows,
                                      long long cols) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;

    int last_y = (cols / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    } else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float y_buf[32];
    float dy_buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_d_output = d_output + row_offset * cols;
        T *row_output = output + row_offset * cols;
        T *row_d_input = d_input + row_offset * cols;

        float thread_max = -1 * CUDART_INF_F;

#pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
            dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
        }

        float thread_sum = 0.f;

#pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_sum += y_buf[i] * dy_buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

#pragma unroll
        for (int i = 0; i < cols_this_thread; ++i) {
            row_d_input[lane_id * cols_per_thread + i] =
                static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
        }
    }
}
```
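The store implements the softmax Jacobian-vector product, with `warp_sum` holding the per-row dot product of the output and the incoming gradient:

```math
\frac{\partial L}{\partial x_i} = y_i\left(\frac{\partial L}{\partial y_i} - \sum_j y_j \frac{\partial L}{\partial y_j}\right)
```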
The old `softmax` entry point, which only distinguished float32 from bfloat16 (`fastfold_softmax_fp32` vs `fastfold_softmax_bfp16`), is replaced by the dispatching version shown above, and `softmax_gradient` (`@@ -271,11 +169,15 @@`) gets the same three-way dispatch:

```cpp
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows,
                            long long cols) {
    CHECK_INPUT(output);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));

    at::Tensor grad_input = at::empty_like(output);
    // ...
    dim3 block(128);

    if (output.dtype() == torch::kFloat32) {
        fastfold_softmax_grad<float><<<grid, block>>>(
            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
            (float *)grad_input.data_ptr(), rows, cols);
    } else if (output.dtype() == torch::kFloat16) {
        fastfold_softmax_grad<at::Half><<<grid, block>>>(
            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
            (at::Half *)grad_input.data_ptr(), rows, cols);
    } else if (output.dtype() == torch::kBFloat16) {
        fastfold_softmax_grad<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)grad_input.data_ptr(), rows, cols);
    }
    // ... (remainder of the function is unchanged)
```
Next comes the fused scale+mask softmax (`@@ -285,8 +187,9 @@`, `@@ -297,8 +200,7 @@`, `@@ -307,80 +209,21 @@`): `fastfold_softmax_scale_mask_fp32` is deleted and the templated kernel takes over. Masked-out positions (mask value 0) are filled with a large negative number before the row softmax (both `-1 * 1e9` and `-1 * 10e9` appear as fill constants across the old kernels; the `10e9` form shown alongside the templated lines is used below), and the mask row is selected as `row_offset / (head * cols)`, so one mask row is shared by `head * cols` consecutive data rows:

```cpp
////////////////

template <typename T>
__global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long long rows,
                                            long long cols, float scale, int head) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;

    int last_y = (cols / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    } else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_input = input + row_offset * cols;
        T *row_output = output + row_offset * cols;
        T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);

#pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                buf[i] = -1 * 10e9;
            } else {
                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
            }
        }

        float thread_max = -1 * CUDART_INF_F;

#pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_max = max(thread_max, buf[i]);
        }

        float warp_max = WarpAllReduceMax(thread_max);

        float thread_sum = 0.f;
#pragma unroll
        for (int i = 0; i < cols_this_thread; ++i) {
            buf[i] = __expf(buf[i] - warp_max);
            thread_sum += buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

#pragma unroll
        for (int i = 0; i < cols_this_thread; ++i) {
            row_output[lane_id * cols_per_thread + i] = static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}
```
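In effect each row computes `softmax(scale * x)` with masked positions forced to a large negative value so that they contribute essentially zero probability, the usual attention-logit masking pattern.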
The matching `fastfold_softmax_scale_mask_bfp16` kernel is deleted; its body was the same algorithm with `at::BFloat16` pointers and `static_cast<float>`/`static_cast<at::BFloat16>` conversions, all of which the template above now covers (`@@ -388,23 +231,23 @@`).
`fused_scale_mask_softmax_forward` keeps its signature and input checks and, in `@@ -415,79 +258,26 @@`, switches from the fp32/bf16 pair to the templated kernel with an added float16 branch:

```cpp
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
                                            long long cols, float scale) {
    CHECK_INPUT(input);
    CHECK_INPUT(mask);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    // ...
    dim3 block(128);

    if (input.dtype() == torch::kFloat32) {
        fastfold_softmax_scale_mask<float><<<grid, block>>>(
            (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)output.data_ptr(),
            rows, cols, scale, head);
    } else if (input.dtype() == torch::kFloat16) {
        fastfold_softmax_scale_mask<at::Half><<<grid, block>>>(
            (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
            (at::Half *)output.data_ptr(), rows, cols, scale, head);
    } else if (input.dtype() == torch::kBFloat16) {
        fastfold_softmax_scale_mask<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
            (at::BFloat16 *)output.data_ptr(), rows, cols, scale, head);
    }

    return output;
}
```
The two masked-softmax backward kernels, `fastfold_softmax_scale_mask_grad_fp32` and `fastfold_softmax_scale_mask_grad_bfp16`, are merged the same way (`@@ -498,8 +288,7 @@`, `@@ -509,33 +298,33 @@`). Note the pre-existing `else` branch that assigns `0` to the pointer `row_d_input` rather than to an element; the diff keeps that line as it was:

```cpp
template <typename T>
__global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_input, T *mask,
                                                 long long rows, long long cols, float scale,
                                                 int head) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;

    int last_y = (cols / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    } else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float y_buf[32];
    float dy_buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_d_output = d_output + row_offset * cols;
        T *row_output = output + row_offset * cols;
        T *row_d_input = d_input + row_offset * cols;
        T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);

        float thread_max = -1 * CUDART_INF_F;

#pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
            dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
        }

        float thread_sum = 0.f;

#pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_sum += y_buf[i] * dy_buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

#pragma unroll
        for (int i = 0; i < cols_this_thread; ++i) {
            if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
                row_d_input[lane_id * cols_per_thread + i] =
                    static_cast<T>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
            } else {
                row_d_input = 0;
            }
        }
    }
}
```
`fused_scale_mask_softmax_backward` keeps its signature and checks and, in `@@ -555,11 +345,16 @@`, dispatches the templated backward kernel for the three dtypes:

```cpp
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output,
                                             at::Tensor mask, long long rows, long long cols,
                                             float scale) {
    CHECK_INPUT(output);
    CHECK_INPUT(mask);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
    // ...
    dim3 block(128);

    if (output.dtype() == torch::kFloat32) {
        fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
            (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
    } else if (output.dtype() == torch::kFloat16) {
        fastfold_softmax_scale_mask_grad<at::Half><<<grid, block>>>(
            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
            (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, scale,
            head);
    } else if (output.dtype() == torch::kBFloat16) {
        fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
            scale, head);
    }
    // ... (remainder unchanged)
```
Last, the scale+mask+bias variant (`@@ -570,9 +365,10 @@`, `@@ -583,69 +379,7 @@`, `@@ -654,23 +388,23 @@`): both `fastfold_softmax_scale_mask_bias_fp32` and `fastfold_softmax_scale_mask_bias_bfp16` are deleted in favour of one template. The bias row is selected as `row_offset % (head * cols)` while the mask row uses `row_offset / (head * cols)`:

```cpp
////////////////

template <typename T>
__global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *output,
                                                 long long rows, long long cols, float scale,
                                                 int head) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;

    int last_y = (cols / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    } else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_input = input + row_offset * cols;
        T *row_output = output + row_offset * cols;
        T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
        T *bias_ptr = bias + ((row_offset % (head * cols)) * cols);

#pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                buf[i] = -1 * 10e9;
            } else {
                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
                buf[i] += static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
            }
        }

        float thread_max = -1 * CUDART_INF_F;

#pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_max = max(thread_max, buf[i]);
        }

        float warp_max = WarpAllReduceMax(thread_max);

        float thread_sum = 0.f;
#pragma unroll
        for (int i = 0; i < cols_this_thread; ++i) {
            buf[i] = __expf(buf[i] - warp_max);
            thread_sum += buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

#pragma unroll
        for (int i = 0; i < cols_this_thread; ++i) {
            row_output[lane_id * cols_per_thread + i] = static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}
```
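The `/` versus `%` indexing means one mask row is reused across `head * cols` consecutive data rows, while the bias pattern repeats every `head * cols` rows; presumably this matches a layout in which the mask broadcasts over the attention heads and the bias broadcasts over the batch dimension in the modules that call these kernels.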
`fused_scale_mask_bias_softmax_forward` (`@@ -706,14 +440,18 @@`) gets the corresponding dispatch:

```cpp
    dim3 block(128);

    if (input.dtype() == torch::kFloat32) {
        fastfold_softmax_scale_mask_bias<float><<<grid, block>>>(
            (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
            (float *)output.data_ptr(), rows, cols, scale, head);
    } else if (input.dtype() == torch::kFloat16) {
        fastfold_softmax_scale_mask_bias<at::Half><<<grid, block>>>(
            (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), (at::Half *)bias.data_ptr(),
            (at::Half *)output.data_ptr(), rows, cols, scale, head);
    } else if (input.dtype() == torch::kBFloat16) {
        fastfold_softmax_scale_mask_bias<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
            (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols, scale,
            head);
    }

    return output;
```
`fused_scale_mask_bias_softmax_backward` (`@@ -732,11 +470,16 @@`) reuses `fastfold_softmax_scale_mask_grad` for its gradient, again with the three dtype branches:

```cpp
    dim3 block(128);

    if (output.dtype() == torch::kFloat32) {
        fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
            (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
    } else if (output.dtype() == torch::kFloat16) {
        fastfold_softmax_scale_mask_grad<at::Half><<<grid, block>>>(
            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
            (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, scale,
            head);
    } else if (output.dtype() == torch::kBFloat16) {
        fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
            scale, head);
    }
```
tests/test_fastnn/test_softmax.py (new file, 0 → 100644, view file @ a65d5009)
```python
import torch

from fastfold.model.fastnn.kernel import softmax


def test_softmax():
    # [batch, dim]
    test_shape = [[64, 64], [64, 128], [64, 129], [64, 1024]]
    test_dtype = [torch.float32, torch.float16, torch.bfloat16]
    test_device = torch.device("cuda")
    tolerance_eps = {torch.float32: 10e-5, torch.float16: 10e-2, torch.bfloat16: 10e-2}

    for shape in test_shape:
        for dtype in test_dtype:
            sample_input = torch.rand(shape).to(device=test_device,
                                                dtype=dtype).requires_grad_(True)
            sample_input_fastnn = torch.clone(sample_input.detach()).requires_grad_(True)

            # Forward
            torch_out = torch.nn.functional.softmax(sample_input, dim=-1)
            fastnn_out = softmax(sample_input_fastnn)

            forward_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
            assert forward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"

            # Backward
            out_grad = torch.rand_like(torch_out).requires_grad_(False)

            torch_out.backward(out_grad)
            fastnn_out.backward(out_grad)

            backward_error = torch.max(
                torch.abs(sample_input.grad - sample_input_fastnn.grad)).cpu().item()
            assert backward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"
```