OpenDAS / FastFold · Commits

Unverified commit a65d5009, authored Jun 03, 2022 by shenggan, committed via GitHub on Jun 03, 2022
Parent: 771d4b83

    use template for fused softmax & add unittest for fused softmax (#26)

Showing 2 changed files with 171 additions and 394 deletions:

    fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu   +137  -394
    tests/test_fastnn/test_softmax.py                                       +34   -0
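The change collapses each _fp32 / _bfp16 kernel pair into a single template <typename T> kernel and lets the host wrappers dispatch on input.dtype() to instantiate it for float, at::Half, or at::BFloat16, which is also where the new at::Half branches come from. As a self-contained illustration of that pattern (a minimal sketch only, not FastFold code; scale_kernel, launch_scale, and DType are invented for this example), compiled with nvcc:

// Minimal sketch (not FastFold code): one templated kernel plus a host-side
// dtype dispatch, the same pattern this commit applies to the fused softmax kernels.
#include <cuda_fp16.h>
#include <cstdio>

template <typename T>
__global__ void scale_kernel(const T *in, T *out, float scale, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // Widen to float for the arithmetic, store back in T, mirroring the
        // static_cast<float>(...) loads in the templated softmax kernels.
        out[i] = static_cast<T>(static_cast<float>(in[i]) * scale);
    }
}

enum class DType { kFloat32, kFloat16 };

// Host-side dispatch: pick the template instantiation from a runtime dtype tag,
// mirroring the if / else-if chain on input.dtype() in the new wrapper functions.
void launch_scale(void *in, void *out, float scale, int n, DType dtype) {
    dim3 block(128);
    dim3 grid((n + 127) / 128);
    if (dtype == DType::kFloat32) {
        scale_kernel<float><<<grid, block>>>((const float *)in, (float *)out, scale, n);
    } else if (dtype == DType::kFloat16) {
        scale_kernel<__half><<<grid, block>>>((const __half *)in, (__half *)out, scale, n);
    }
}

int main() {
    const int n = 8;
    float h_in[n], h_out[n];
    for (int i = 0; i < n; ++i) h_in[i] = (float)i;

    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    launch_scale(d_in, d_out, 0.5f, n, DType::kFloat32);
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

    for (int i = 0; i < n; ++i) printf("%g ", h_out[i]);  // expect 0 0.5 1 1.5 ...
    printf("\n");
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

The FastFold wrappers in the diff below do the same thing, with at::Tensor::data_ptr() casts in place of the raw pointers used here.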
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu

-#include <c10/cuda/CUDAGuard.h>
 #include <math_constants.h>
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <iostream>

@@ -30,7 +30,8 @@ __inline__ __device__ float WarpAllReduceSum(float val) {
 ////////////////

-__global__ void fastfold_softmax_fp32(float *input, float *output, long long rows, long long cols) {
+template <typename T>
+__global__ void fastfold_softmax(T *input, T *output, long long rows, long long cols) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;

@@ -41,8 +42,7 @@ __global__ void fastfold_softmax_fp32(float *input, float *output, long long row
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
     } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }
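A quick worked check of the lane/column split kept as context in the hunk above and reused by every kernel in this file (host-side C++ sketch, not part of the diff): with cols = 129, cols_per_thread = (129 + 31) / 32 = 5 and last_y = 129 / 5 = 25, so lanes 0-24 each own 5 columns, lane 25 owns 129 - 125 = 4, and lanes 26-31 own none.

// Standalone check of the lane/column partition logic (illustrative only).
#include <cassert>
#include <cstdio>

int main() {
    const long long cols = 129;                  // e.g. the [64, 129] shape in the new test
    const int cols_per_thread = (cols + 31) / 32;
    const int last_y = cols / cols_per_thread;

    long long covered = 0;
    for (int lane = 0; lane < 32; ++lane) {
        int cols_this_thread = cols_per_thread;
        if (lane == last_y) {
            cols_this_thread = cols - cols_per_thread * last_y;
        } else if (lane > last_y) {
            cols_this_thread = 0;
        }
        covered += cols_this_thread;
        printf("lane %2d -> %d cols\n", lane, cols_this_thread);
    }
    assert(covered == cols);                     // every column is owned by exactly one lane
    return 0;
}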
@@ -51,72 +51,17 @@ __global__ void fastfold_softmax_fp32(float *input, float *output, long long row
     int lane_id = threadidx_y;

     if (row_offset < rows) {
-        float *row_input = input + row_offset * cols;
-        float *row_output = output + row_offset * cols;
+        T *row_input = input + row_offset * cols;
+        T *row_output = output + row_offset * cols;
         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            buf[i] = row_input[lane_id * cols_per_thread + i];
-        }
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            thread_max = max(thread_max, buf[i]);
-        }
-
-        float warp_max = WarpAllReduceMax(thread_max);
-
-        float thread_sum = 0.f;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            buf[i] = __expf(buf[i] - warp_max);
-            thread_sum += buf[i];
-        }
-
-        float warp_sum = WarpAllReduceSum(thread_sum);
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
-        }
-    }
-}
-
-__global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output, long long rows,
-                                       long long cols) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-    int last_y = (cols / cols_per_thread);
-
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    } else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
-
-    float buf[32];
-
-    int lane_id = threadidx_y;
-
-    if (row_offset < rows) {
-        at::BFloat16 *row_input = input + row_offset * cols;
-        at::BFloat16 *row_output = output + row_offset * cols;
-        float thread_max = -1 * CUDART_INF_F;
-
-#pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
         }

 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }
@@ -124,74 +69,47 @@ __global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output
         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
-                static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
+                static_cast<T>(__fdividef(buf[i], warp_sum));
         }
     }
 }

-__global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float *d_input, long long rows,
-                                           long long cols) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-    int last_y = (cols / cols_per_thread);
-
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    } else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
-
-    float y_buf[32];
-    float dy_buf[32];
-
-    int lane_id = threadidx_y;
-
-    if (row_offset < rows) {
-        float *row_d_output = d_output + row_offset * cols;
-        float *row_output = output + row_offset * cols;
-        float *row_d_input = d_input + row_offset * cols;
-        float thread_max = -1 * CUDART_INF_F;
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            y_buf[i] = row_output[lane_id * cols_per_thread + i];
-            dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
-        }
-
-        float thread_sum = 0.f;
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            thread_sum += y_buf[i] * dy_buf[i];
-        }
-
-        float warp_sum = WarpAllReduceSum(thread_sum);
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            row_d_input[lane_id * cols_this_thread + i] = (dy_buf[i] - warp_sum) * y_buf[i];
-        }
-    }
-}
+at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
+    CHECK_INPUT(input);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+    at::Tensor output = at::empty_like(input);
+
+    int grid = (rows + 3) / 4;
+    dim3 block(128);
+
+    if (input.dtype() == torch::kFloat32) {
+        fastfold_softmax<float>
+            <<<grid, block>>>((float *)input.data_ptr(), (float *)output.data_ptr(), rows, cols);
+    } else if (input.dtype() == torch::kFloat16) {
+        fastfold_softmax<at::Half><<<grid, block>>>((at::Half *)input.data_ptr(),
+                                                    (at::Half *)output.data_ptr(), rows, cols);
+    } else if (input.dtype() == torch::kBFloat16) {
+        fastfold_softmax<at::BFloat16><<<grid, block>>>(
+            (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols);
+    }
+
+    return output;
+}

-__global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,
-                                            at::BFloat16 *d_input, long long rows, long long cols) {
+template <typename T>
+__global__ void fastfold_softmax_grad(T *d_output, T *output, T *d_input, long long rows,
+                                      long long cols) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -202,8 +120,7 @@ __global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
     } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }

@@ -213,56 +130,37 @@ __global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16
     int lane_id = threadidx_y;

     if (row_offset < rows) {
-        at::BFloat16 *row_d_output = d_output + row_offset * cols;
-        at::BFloat16 *row_output = output + row_offset * cols;
-        at::BFloat16 *row_d_input = d_input + row_offset * cols;
+        T *row_d_output = d_output + row_offset * cols;
+        T *row_output = output + row_offset * cols;
+        T *row_d_input = d_input + row_offset * cols;
         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
-            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
-            dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
+            y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
+            dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
         }

         float thread_sum = 0.f;
 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             thread_sum += y_buf[i] * dy_buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_d_input[lane_id * cols_per_thread + i] =
-                static_cast<at::BFloat16>((dy_buf[i] - warp_sum) * y_buf[i]);
+                static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
         }
     }
 }

-at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
-    CHECK_INPUT(input);
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-
-    at::Tensor output = at::empty_like(input);
-
-    int grid = (rows + 3) / 4;
-    dim3 block(128);
-
-    if (input.dtype() == torch::kFloat32) {
-        fastfold_softmax_fp32<<<grid, block>>>((float *)input.data_ptr(),
-                                               (float *)output.data_ptr(), rows, cols);
-    } else {
-        fastfold_softmax_bfp16<<<grid, block>>>((at::BFloat16 *)input.data_ptr(),
-                                                (at::BFloat16 *)output.data_ptr(), rows, cols);
-    }
-
-    return output;
-}
-
 at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows,
                             long long cols) {
     CHECK_INPUT(output);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
     at::Tensor grad_input = at::empty_like(output);
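For reference, the identity the grad kernels above implement (a restatement of the code shown, not part of the diff): for a row y = softmax(x) and upstream gradient dy,

    dx_i = y_i * (dy_i - sum_j dy_j * y_j)

Each warp reduces warp_sum = sum_j dy_j * y_j with WarpAllReduceSum from the per-lane thread_sum, then writes (dy_buf[i] - warp_sum) * y_buf[i], cast back to the tensor dtype T.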
@@ -271,11 +169,15 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long ro
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {
-        fastfold_softmax_grad_fp32<<<grid, block>>>((float *)d_output.data_ptr(),
-                                                    (float *)output.data_ptr(),
-                                                    (float *)grad_input.data_ptr(), rows, cols);
-    } else {
-        fastfold_softmax_grad_bfp16<<<grid, block>>>(
+        fastfold_softmax_grad<float><<<grid, block>>>((float *)d_output.data_ptr(),
+                                                      (float *)output.data_ptr(),
+                                                      (float *)grad_input.data_ptr(), rows, cols);
+    } else if (output.dtype() == torch::kFloat16) {
+        fastfold_softmax_grad<at::Half><<<grid, block>>>((at::Half *)d_output.data_ptr(),
+                                                         (at::Half *)output.data_ptr(),
+                                                         (at::Half *)grad_input.data_ptr(), rows, cols);
+    } else if (output.dtype() == torch::kBFloat16) {
+        fastfold_softmax_grad<at::BFloat16><<<grid, block>>>(
             (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
             (at::BFloat16 *)grad_input.data_ptr(), rows, cols);
     }

@@ -285,7 +187,8 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long ro
 ////////////////

-__global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, float *output,
-                                                 long long rows, long long cols, float scale,
-                                                 int head) {
+template <typename T>
+__global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long long rows,
+                                            long long cols, float scale, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;

@@ -297,8 +200,7 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
     } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }

@@ -307,80 +209,21 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
     int lane_id = threadidx_y;

     if (row_offset < rows) {
-        float *row_input = input + row_offset * cols;
-        float *row_output = output + row_offset * cols;
-        float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
+        T *row_input = input + row_offset * cols;
+        T *row_output = output + row_offset * cols;
+        T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);

 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                 buf[i] = -1 * 1e9;
             } else {
-                buf[i] = row_input[lane_id * cols_per_thread + i] * scale;
+                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
             }
         }
-
-        float thread_max = -1 * CUDART_INF_F;
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            thread_max = max(thread_max, buf[i]);
-        }
-
-        float warp_max = WarpAllReduceMax(thread_max);
-
-        float thread_sum = 0.f;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            buf[i] = __expf(buf[i] - warp_max);
-            thread_sum += buf[i];
-        }
-
-        float warp_sum = WarpAllReduceSum(thread_sum);
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
-        }
-    }
-}
-
-__global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
-                                                  at::BFloat16 *output, long long rows,
-                                                  long long cols, float scale, int head) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-    int last_y = (cols / cols_per_thread);
-
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    } else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
-
-    float buf[32];
-
-    int lane_id = threadidx_y;
-
-    if (row_offset < rows) {
-        at::BFloat16 *row_input = input + row_offset * cols;
-        at::BFloat16 *row_output = output + row_offset * cols;
-        at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
-                buf[i] = -1 * 10e9;
-            } else {
-                buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]) * scale;
-            }
-        }

         float thread_max = -1 * CUDART_INF_F;
 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }
@@ -388,23 +231,23 @@ __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloa
         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
-                static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
+                static_cast<T>(__fdividef(buf[i], warp_sum));
         }
     }
 }

-at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols,
-                                            float scale) {
+at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
+                                            long long cols, float scale) {
     CHECK_INPUT(input);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

@@ -415,79 +258,26 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
     dim3 block(128);

     if (input.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_fp32<<<grid, block>>>(
-            (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)output.data_ptr(), rows,
-            cols, scale, head);
-    } else {
-        fastfold_softmax_scale_mask_bfp16<<<grid, block>>>(
-            (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-            (at::BFloat16 *)output.data_ptr(), rows, cols, scale, head);
+        fastfold_softmax_scale_mask<float>
+            <<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(),
+                              (float *)output.data_ptr(), rows, cols, scale, head);
+    } else if (input.dtype() == torch::kFloat16) {
+        fastfold_softmax_scale_mask<at::Half>
+            <<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
+                              (at::Half *)output.data_ptr(), rows, cols, scale, head);
+    } else if (input.dtype() == torch::kBFloat16) {
+        fastfold_softmax_scale_mask<at::BFloat16>
+            <<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+                              (at::BFloat16 *)output.data_ptr(), rows, cols, scale, head);
     }

     return output;
 }

-__global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *output,
-                                                      float *d_input, float *mask, long long rows,
-                                                      long long cols, float scale, int head) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-    int last_y = (cols / cols_per_thread);
-
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    } else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
-
-    float y_buf[32];
-    float dy_buf[32];
-
-    int lane_id = threadidx_y;
-
-    if (row_offset < rows) {
-        float *row_d_output = d_output + row_offset * cols;
-        float *row_output = output + row_offset * cols;
-        float *row_d_input = d_input + row_offset * cols;
-        float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
-        float thread_max = -1 * CUDART_INF_F;
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            y_buf[i] = row_output[lane_id * cols_per_thread + i];
-            dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
-        }
-
-        float thread_sum = 0.f;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            thread_sum += y_buf[i] * dy_buf[i];
-        }
-
-        float warp_sum = WarpAllReduceSum(thread_sum);
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
-                row_d_input[lane_id * cols_per_thread + i] =
-                    scale * ((dy_buf[i] - warp_sum) * y_buf[i]);
-            } else {
-                row_d_input = 0;
-            }
-        }
-    }
-}
-
-__global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,
-                                                       at::BFloat16 *d_input, at::BFloat16 *mask,
-                                                       long long rows, long long cols, float scale,
-                                                       int head) {
+template <typename T>
+__global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_input, T *mask,
+                                                 long long rows, long long cols, float scale,
+                                                 int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
@@ -498,8 +288,7 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
     } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }

@@ -509,33 +298,33 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
     int lane_id = threadidx_y;

     if (row_offset < rows) {
-        at::BFloat16 *row_d_output = d_output + row_offset * cols;
-        at::BFloat16 *row_output = output + row_offset * cols;
-        at::BFloat16 *row_d_input = d_input + row_offset * cols;
-        at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
+        T *row_d_output = d_output + row_offset * cols;
+        T *row_output = output + row_offset * cols;
+        T *row_d_input = d_input + row_offset * cols;
+        T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
         float thread_max = -1 * CUDART_INF_F;

 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
-            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
-            dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
+            y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
+            dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
         }

         float thread_sum = 0.f;
 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             thread_sum += y_buf[i] * dy_buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
                 row_d_input[lane_id * cols_per_thread + i] =
-                    static_cast<at::BFloat16>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
+                    static_cast<T>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
             } else {
                 row_d_input = 0;
             }
         }

@@ -544,7 +333,8 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
 }

-at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output,
-                                             at::Tensor mask, long long rows, long long cols, float scale) {
+at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output, at::Tensor mask,
+                                             long long rows, long long cols, float scale) {
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));

@@ -555,11 +345,16 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_grad_fp32<<<grid, block>>>(
+        fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
             (float *)d_output.data_ptr(), (float *)output.data_ptr(),
             (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
-    } else {
-        fastfold_softmax_scale_mask_grad_bfp16<<<grid, block>>>(
+    } else if (output.dtype() == torch::kFloat16) {
+        fastfold_softmax_scale_mask_grad<at::Half><<<grid, block>>>(
+            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
+            (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, scale,
+            head);
+    } else if (output.dtype() == torch::kBFloat16) {
+        fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
             (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
             (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
             scale, head);
@@ -570,9 +365,10 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
 ////////////////

-__global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask, float *bias,
-                                                      float *output, long long rows, long long cols,
-                                                      float scale, int head) {
+template <typename T>
+__global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *output,
+                                                 long long rows, long long cols, float scale,
+                                                 int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;

@@ -583,8 +379,7 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
     } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }

@@ -593,23 +388,23 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
     int lane_id = threadidx_y;

     if (row_offset < rows) {
-        float *row_input = input + row_offset * cols;
-        float *row_output = output + row_offset * cols;
-        float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
-        float *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
+        T *row_input = input + row_offset * cols;
+        T *row_output = output + row_offset * cols;
+        T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
+        T *bias_ptr = bias + ((row_offset % (head * cols)) * cols);

 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                 buf[i] = -1 * 10e9;
             } else {
-                buf[i] = row_input[lane_id * cols_per_thread + i] * scale +
-                         bias_ptr[lane_id * cols_per_thread + i];
+                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
+                buf[i] += static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
             }
         }

         float thread_max = -1 * CUDART_INF_F;
 #pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }

@@ -617,78 +412,17 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
 #pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
 #pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
-        }
-    }
-}
-
-__global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
-                                                       at::BFloat16 *bias, at::BFloat16 *output,
-                                                       long long rows, long long cols, float scale,
-                                                       int head) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-    int last_y = (cols / cols_per_thread);
-
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    } else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
-
-    float buf[32];
-
-    int lane_id = threadidx_y;
-
-    if (row_offset < rows) {
-        at::BFloat16 *row_input = input + row_offset * cols;
-        at::BFloat16 *row_output = output + row_offset * cols;
-        at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
-        at::BFloat16 *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
-
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
-                buf[i] = -1 * 10e9;
-            } else {
-                buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]) * scale;
-                buf[i] += static_cast<float>(bias_ptr[lane_id * cols_per_thread + i]);
-            }
-        }
-
-        float thread_max = -1 * CUDART_INF_F;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            thread_max = max(thread_max, buf[i]);
-        }
-
-        float warp_max = WarpAllReduceMax(thread_max);
-
-        float thread_sum = 0.f;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            buf[i] = __expf(buf[i] - warp_max);
-            thread_sum += buf[i];
-        }
-
-        float warp_sum = WarpAllReduceSum(thread_sum);
-#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
-                static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
+                static_cast<T>(__fdividef(buf[i], warp_sum));
         }
     }
 }
@@ -706,14 +440,18 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
     dim3 block(128);

     if (input.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_bias_fp32<<<grid, block>>>(
+        fastfold_softmax_scale_mask_bias<float><<<grid, block>>>(
             (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
             (float *)output.data_ptr(), rows, cols, scale, head);
-    } else {
-        fastfold_softmax_scale_mask_bias_bfp16<<<grid, block>>>(
-            (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
-            (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols, scale,
-            head);
+    } else if (input.dtype() == torch::kFloat16) {
+        fastfold_softmax_scale_mask_bias<at::Half><<<grid, block>>>(
+            (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), (at::Half *)bias.data_ptr(),
+            (at::Half *)output.data_ptr(), rows, cols, scale, head);
+    } else if (input.dtype() == torch::kBFloat16) {
+        fastfold_softmax_scale_mask_bias<at::BFloat16><<<grid, block>>>(
+            (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
+            (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols, scale,
+            head);
     }

     return output;

@@ -732,11 +470,16 @@ at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tenso
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_grad_fp32<<<grid, block>>>(
+        fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
             (float *)d_output.data_ptr(), (float *)output.data_ptr(),
             (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
-    } else {
-        fastfold_softmax_scale_mask_grad_bfp16<<<grid, block>>>(
+    } else if (output.dtype() == torch::kFloat16) {
+        fastfold_softmax_scale_mask_grad<at::Half><<<grid, block>>>(
+            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
+            (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, scale,
+            head);
+    } else if (output.dtype() == torch::kBFloat16) {
+        fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>(
             (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
             (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
             scale, head);
 ...
tests/test_fastnn/test_softmax.py (new file, mode 100644)

+import torch
+
+from fastfold.model.fastnn.kernel import softmax
+
+
+def test_softmax():
+    # [batch, dim]
+    test_shape = [[64, 64], [64, 128], [64, 129], [64, 1024]]
+    test_dtype = [torch.float32, torch.float16, torch.bfloat16]
+    test_device = torch.device("cuda")
+    tolerance_eps = {torch.float32: 10e-5, torch.float16: 10e-2, torch.bfloat16: 10e-2}
+
+    for shape in test_shape:
+        for dtype in test_dtype:
+            sample_input = torch.rand(shape).to(device=test_device,
+                                                dtype=dtype).requires_grad_(True)
+            sample_input_fastnn = torch.clone(sample_input.detach()).requires_grad_(True)
+
+            # Forward
+            torch_out = torch.nn.functional.softmax(sample_input, dim=-1)
+            fastnn_out = softmax(sample_input_fastnn)
+
+            forward_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
+            assert forward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"
+
+            # Backward
+            out_grad = torch.rand_like(torch_out).requires_grad_(False)
+            torch_out.backward(out_grad)
+            fastnn_out.backward(out_grad)
+
+            backward_error = torch.max(
+                torch.abs(sample_input.grad - sample_input_fastnn.grad)).cpu().item()
+            assert backward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"