OpenDAS / FastFold · Commits · a65d5009

Commit a65d5009 (unverified), authored Jun 03, 2022 by shenggan, committed by GitHub on Jun 03, 2022

use template for fused softmax & add unittest for fused softmax (#26)

parent 771d4b83

Showing 2 changed files with 171 additions and 394 deletions:

    fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu   +137 -394
    tests/test_fastnn/test_softmax.py                                      +34 -0
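The change replaces the per-dtype kernel pairs (`_fp32` / `_bfp16`) with one templated kernel per operation and a host-side dtype check that now also covers `torch::kFloat16`. As a rough illustration of that pattern only (hypothetical names and a trivial kernel body, not code from this commit), a templated CUDA kernel plus explicit dispatch looks like this:

// Sketch only: one templated kernel, instantiated per supported dtype and
// selected by an explicit dtype check on the host side.
#include <torch/extension.h>

template <typename T>
__global__ void scale_by_two(T *x, long long n) {
    long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // compute in fp32, store back as T (same trick the fused softmax uses)
        x[i] = static_cast<T>(2.f * static_cast<float>(x[i]));
    }
}

void scale_by_two_inplace(at::Tensor t) {
    long long n = t.numel();
    dim3 block(128);
    int grid = (n + 127) / 128;
    if (t.dtype() == torch::kFloat32) {
        scale_by_two<float><<<grid, block>>>(t.data_ptr<float>(), n);
    } else if (t.dtype() == torch::kFloat16) {
        scale_by_two<at::Half><<<grid, block>>>((at::Half *)t.data_ptr(), n);
    } else if (t.dtype() == torch::kBFloat16) {
        scale_by_two<at::BFloat16><<<grid, block>>>((at::BFloat16 *)t.data_ptr(), n);
    }
}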
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu (view file @ a65d5009)
#include <c10/cuda/CUDAGuard.h>
#include <math_constants.h>
#include <torch/extension.h>
#include <iostream>
...
@@ -30,7 +30,8 @@ __inline__ __device__ float WarpAllReduceSum(float val) {
 ////////////////

-__global__ void fastfold_softmax_fp32(float *input, float *output, long long rows, long long cols) {
+template <typename T>
+__global__ void fastfold_softmax(T *input, T *output, long long rows, long long cols) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
...
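In both the old and the templated kernel, a 128-thread block is treated as 4 warps: `threadidx_x = threadIdx.x / 32` is the warp index within the block (one softmax row per warp, hence `blockIdx.x * 4`), and `threadidx_y = threadIdx.x % 32` is the lane, which owns a contiguous slice of the row's columns. A small host-side sketch (not from the commit) of the per-lane column split applied in the next hunk, for a hypothetical cols = 129:

// Sketch: reproduce the kernel's per-lane work split on the host for cols = 129.
#include <cstdio>

int main() {
    int cols = 129;                          // hypothetical column count
    int cols_per_thread = (cols + 31) / 32;  // 5 columns per lane (rounded up)
    int last_y = cols / cols_per_thread;     // lane 25 gets the remainder
    int total = 0;
    for (int lane = 0; lane < 32; ++lane) {
        int cols_this_thread = cols_per_thread;
        if (lane == last_y) {
            cols_this_thread = cols - cols_per_thread * last_y;  // 129 - 5*25 = 4
        } else if (lane > last_y) {
            cols_this_thread = 0;                                // idle lanes
        }
        total += cols_this_thread;
    }
    std::printf("cols_per_thread=%d last_y=%d total=%d\n", cols_per_thread, last_y, total);
    return 0;  // prints cols_per_thread=5 last_y=25 total=129
}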
@@ -41,8 +42,7 @@ __global__ void fastfold_softmax_fp32(float *input, float *output, long long row
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
     } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }
...
@@ -51,72 +51,17 @@ __global__ void fastfold_softmax_fp32(float *input, float *output, long long row
     int lane_id = threadidx_y;

     if (row_offset < rows) {
-        float *row_input = input + row_offset * cols;
-        float *row_output = output + row_offset * cols;
+        T *row_input = input + row_offset * cols;
+        T *row_output = output + row_offset * cols;

         float thread_max = -1 * CUDART_INF_F;

#pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             buf[i] = row_input[lane_id * cols_per_thread + i];
         }

#pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }

         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
         }
     }
 }

-__global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output, long long rows, long long cols) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-    int last_y = (cols / cols_per_thread);
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    } else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
-    float buf[32];
-    int lane_id = threadidx_y;
-    if (row_offset < rows) {
-        at::BFloat16 *row_input = input + row_offset * cols;
-        at::BFloat16 *row_output = output + row_offset * cols;
-        float thread_max = -1 * CUDART_INF_F;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
-        }
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            thread_max = max(thread_max, buf[i]);
-        }
...
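Each warp therefore computes a numerically stable softmax over one row: take the row maximum (warp-reduced), exponentiate the shifted values, sum them (warp-reduced), then divide. For reference only, a single-threaded CPU sketch of the same per-row computation (not part of the commit):

// Sketch: single-row CPU reference for what one warp computes in the kernel above.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> softmax_row(const std::vector<float> &x) {
    float m = -INFINITY;
    for (float v : x) m = std::max(m, v);   // row max      (WarpAllReduceMax)
    std::vector<float> y(x.size());
    float s = 0.f;
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = std::exp(x[i] - m);          // shifted exp  (__expf)
        s += y[i];                          // row sum      (WarpAllReduceSum)
    }
    for (float &v : y) v /= s;              // normalize    (__fdividef)
    return y;
}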
@@ -124,74 +69,47 @@ __global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output
         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
-                static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
+                static_cast<T>(__fdividef(buf[i], warp_sum));
         }
     }
 }

-__global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float *d_input, long long rows, long long cols) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-    int last_y = (cols / cols_per_thread);
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    } else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
-    float y_buf[32];
-    float dy_buf[32];
-    int lane_id = threadidx_y;
-    if (row_offset < rows) {
-        float *row_d_output = d_output + row_offset * cols;
-        float *row_output = output + row_offset * cols;
-        float *row_d_input = d_input + row_offset * cols;
-        float thread_max = -1 * CUDART_INF_F;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            y_buf[i] = row_output[lane_id * cols_per_thread + i];
-            dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
-        }
-        float thread_sum = 0.f;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            thread_sum += y_buf[i] * dy_buf[i];
-        }
-        float warp_sum = WarpAllReduceSum(thread_sum);
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            row_d_input[lane_id * cols_this_thread + i] = (dy_buf[i] - warp_sum) * y_buf[i];
-        }
-    }
-}

+at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
+    CHECK_INPUT(input);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+    at::Tensor output = at::empty_like(input);
+
+    int grid = (rows + 3) / 4;
+    dim3 block(128);
+
+    if (input.dtype() == torch::kFloat32) {
+        fastfold_softmax<float><<<grid, block>>>((float *)input.data_ptr(), (float *)output.data_ptr(), rows, cols);
+    } else if (input.dtype() == torch::kFloat16) {
+        fastfold_softmax<at::Half><<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)output.data_ptr(), rows, cols);
+    } else if (input.dtype() == torch::kBFloat16) {
+        fastfold_softmax<at::BFloat16><<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols);
+    }
+
+    return output;
+}

-__global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output, at::BFloat16 *d_input, long long rows, long long cols) {
+template <typename T>
+__global__ void fastfold_softmax_grad(T *d_output, T *output, T *d_input, long long rows, long long cols) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
...
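The new host function dispatches explicitly on three dtypes. For comparison only, the same dispatch is often written with PyTorch's AT_DISPATCH_FLOATING_TYPES_AND2 macro; a sketch of that alternative (hypothetical wrapper name, not what this commit does), assuming the templated fastfold_softmax<T> above is in scope:

// Sketch only: the same three-dtype dispatch written with PyTorch's
// AT_DISPATCH_FLOATING_TYPES_AND2 macro. Note it would additionally
// instantiate a double variant, which the explicit if/else above does not.
#include <ATen/Dispatch.h>
#include <torch/extension.h>

at::Tensor softmax_dispatch_sketch(at::Tensor input, long long rows, long long cols) {
    at::Tensor output = at::empty_like(input);
    int grid = (rows + 3) / 4;   // one 128-thread block handles 4 rows (one per warp)
    dim3 block(128);
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(),
        "fastfold_softmax", [&] {
            fastfold_softmax<scalar_t><<<grid, block>>>(
                input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), rows, cols);
        });
    return output;
}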
@@ -202,8 +120,7 @@ __global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
     } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }
...
@@ -213,56 +130,37 @@ __global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16
     int lane_id = threadidx_y;

     if (row_offset < rows) {
-        at::BFloat16 *row_d_output = d_output + row_offset * cols;
-        at::BFloat16 *row_output = output + row_offset * cols;
-        at::BFloat16 *row_d_input = d_input + row_offset * cols;
+        T *row_d_output = d_output + row_offset * cols;
+        T *row_output = output + row_offset * cols;
+        T *row_d_input = d_input + row_offset * cols;

         float thread_max = -1 * CUDART_INF_F;

#pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
-            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
-            dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
+            y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
+            dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
         }

         float thread_sum = 0.f;
#pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             thread_sum += y_buf[i] * dy_buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_d_input[lane_id * cols_per_thread + i] =
-                static_cast<at::BFloat16>((dy_buf[i] - warp_sum) * y_buf[i]);
+                static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
         }
     }
 }

-at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
-    CHECK_INPUT(input);
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-
-    at::Tensor output = at::empty_like(input);
-
-    int grid = (rows + 3) / 4;
-    dim3 block(128);
-
-    if (input.dtype() == torch::kFloat32) {
-        fastfold_softmax_fp32<<<grid, block>>>((float *)input.data_ptr(), (float *)output.data_ptr(), rows, cols);
-    } else {
-        fastfold_softmax_bfp16<<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols);
-    }
-
-    return output;
-}

 at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows, long long cols) {
     CHECK_INPUT(output);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(output));

     at::Tensor grad_input = at::empty_like(output);
...
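The gradient kernel applies the standard softmax backward identity: with y = softmax(x) and upstream gradient dy, dx_i = (dy_i - sum_j y_j * dy_j) * y_i, where the inner product is the warp-reduced `warp_sum` above. A one-row CPU sketch (not from the commit) for reference:

// Sketch: single-row CPU reference for the softmax backward computed above.
#include <cstddef>
#include <vector>

std::vector<float> softmax_backward_row(const std::vector<float> &y,    // softmax output
                                        const std::vector<float> &dy) { // upstream gradient
    float dot = 0.f;
    for (std::size_t i = 0; i < y.size(); ++i) dot += y[i] * dy[i];     // warp_sum
    std::vector<float> dx(y.size());
    for (std::size_t i = 0; i < y.size(); ++i) dx[i] = (dy[i] - dot) * y[i];
    return dx;
}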
@@ -271,11 +169,15 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long ro
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {
-        fastfold_softmax_grad_fp32<<<grid, block>>>((float *)d_output.data_ptr(), (float *)output.data_ptr(), (float *)grad_input.data_ptr(), rows, cols);
-    } else {
-        fastfold_softmax_grad_bfp16<<<grid, block>>>(
+        fastfold_softmax_grad<float><<<grid, block>>>((float *)d_output.data_ptr(), (float *)output.data_ptr(), (float *)grad_input.data_ptr(), rows, cols);
+    } else if (output.dtype() == torch::kFloat16) {
+        fastfold_softmax_grad<at::Half><<<grid, block>>>((at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(), (at::Half *)grad_input.data_ptr(), rows, cols);
+    } else if (output.dtype() == torch::kBFloat16) {
+        fastfold_softmax_grad<at::BFloat16><<<grid, block>>>((at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(), (at::BFloat16 *)grad_input.data_ptr(), rows, cols);
     }
...
@@ -285,8 +187,9 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long ro
 ////////////////

-__global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, float *output, long long rows, long long cols, float scale, int head) {
+template <typename T>
+__global__ void fastfold_softmax_scale_mask(T *input, T *mask, T *output, long long rows, long long cols, float scale, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
...
@@ -297,8 +200,7 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
     } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }
...
@@ -307,80 +209,21 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
     int lane_id = threadidx_y;

     if (row_offset < rows) {
-        float *row_input = input + row_offset * cols;
-        float *row_output = output + row_offset * cols;
-        float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
+        T *row_input = input + row_offset * cols;
+        T *row_output = output + row_offset * cols;
+        T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);

#pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                 buf[i] = -1 * 1e9;
             } else {
                 buf[i] = row_input[lane_id * cols_per_thread + i] * scale;
             }
         }

         float thread_max = -1 * CUDART_INF_F;

#pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }

         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
         }
     }
 }

-__global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloat16 *mask, at::BFloat16 *output, long long rows, long long cols, float scale, int head) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-    int last_y = (cols / cols_per_thread);
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    } else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
-    float buf[32];
-    int lane_id = threadidx_y;
-    if (row_offset < rows) {
-        at::BFloat16 *row_input = input + row_offset * cols;
-        at::BFloat16 *row_output = output + row_offset * cols;
-        at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
-                buf[i] = -1 * 10e9;
-            } else {
-                buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]) * scale;
+                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
-            }
-        }
-        float thread_max = -1 * CUDART_INF_F;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            thread_max = max(thread_max, buf[i]);
-        }
...
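In the scale+mask variants, `mask_ptr = mask + ((row_offset / (head * cols)) * cols)` broadcasts one mask row per batch element. This assumes (it is not stated in the diff) that the logits are laid out as [batch, head, query, key] flattened to rows = batch * head * query with cols = key, and that query == key, so the integer division recovers the batch index. A small sketch of that index arithmetic under those assumptions:

// Sketch: mask-row selection used by the scale+mask kernels, assuming a
// [batch, head, query, key] layout with query == key == cols.
#include <cstdio>

int main() {
    long long batch = 2, head = 4, cols = 8;   // hypothetical sizes
    long long rows = batch * head * cols;      // one softmax row per (batch, head, query)
    for (long long row_offset = 0; row_offset < rows; row_offset += 13) {
        long long mask_row = row_offset / (head * cols);  // batch index
        std::printf("row %lld -> mask row %lld\n", row_offset, mask_row);
    }
    return 0;
}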
@@ -388,23 +231,23 @@ __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloa
         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
-                static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
+                static_cast<T>(__fdividef(buf[i], warp_sum));
         }
     }
 }

 at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows, long long cols, float scale) {
     CHECK_INPUT(input);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
...
@@ -415,79 +258,26 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
     dim3 block(128);

     if (input.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_fp32<<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)output.data_ptr(), rows, cols, scale, head);
-    } else {
-        fastfold_softmax_scale_mask_bfp16<<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols, scale, head);
+        fastfold_softmax_scale_mask<float><<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)output.data_ptr(), rows, cols, scale, head);
+    } else if (input.dtype() == torch::kFloat16) {
+        fastfold_softmax_scale_mask<at::Half><<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), (at::Half *)output.data_ptr(), rows, cols, scale, head);
+    } else if (input.dtype() == torch::kBFloat16) {
+        fastfold_softmax_scale_mask<at::BFloat16><<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols, scale, head);
     }

     return output;
 }

-__global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *output, float *d_input, float *mask, long long rows, long long cols, float scale, int head) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-    int last_y = (cols / cols_per_thread);
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    } else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
-    float y_buf[32];
-    float dy_buf[32];
-    int lane_id = threadidx_y;
-    if (row_offset < rows) {
-        float *row_d_output = d_output + row_offset * cols;
-        float *row_output = output + row_offset * cols;
-        float *row_d_input = d_input + row_offset * cols;
-        float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
-        float thread_max = -1 * CUDART_INF_F;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            y_buf[i] = row_output[lane_id * cols_per_thread + i];
-            dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
-        }
-        float thread_sum = 0.f;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            thread_sum += y_buf[i] * dy_buf[i];
-        }
-        float warp_sum = WarpAllReduceSum(thread_sum);
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
-                row_d_input[lane_id * cols_per_thread + i] = scale * ((dy_buf[i] - warp_sum) * y_buf[i]);
-            } else {
-                row_d_input = 0;
-            }
-        }
-    }
-}

-__global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output, at::BFloat16 *d_input, at::BFloat16 *mask, long long rows, long long cols, float scale, int head) {
+template <typename T>
+__global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_input, T *mask, long long rows, long long cols, float scale, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
...
@@ -498,8 +288,7 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
     } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }
...
@@ -509,33 +298,33 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
     int lane_id = threadidx_y;

     if (row_offset < rows) {
-        at::BFloat16 *row_d_output = d_output + row_offset * cols;
-        at::BFloat16 *row_output = output + row_offset * cols;
-        at::BFloat16 *row_d_input = d_input + row_offset * cols;
-        at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
+        T *row_d_output = d_output + row_offset * cols;
+        T *row_output = output + row_offset * cols;
+        T *row_d_input = d_input + row_offset * cols;
+        T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);

         float thread_max = -1 * CUDART_INF_F;

#pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
-            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
-            dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
+            y_buf[i] = static_cast<T>(row_output[lane_id * cols_per_thread + i]);
+            dy_buf[i] = static_cast<T>(row_d_output[lane_id * cols_per_thread + i]);
         }

         float thread_sum = 0.f;
#pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             thread_sum += y_buf[i] * dy_buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
                 row_d_input[lane_id * cols_per_thread + i] =
-                    static_cast<at::BFloat16>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
+                    static_cast<T>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
             } else {
                 row_d_input = 0;
             }
...
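A note on the masked-out branch above: both the removed `_bfp16` kernel and the templated replacement write `row_d_input = 0;` in the `else` branch, which overwrites the row pointer itself rather than storing a zero gradient for that element; an element-wise write would instead look like `row_d_input[lane_id * cols_per_thread + i] = 0;`. The commit leaves this behaviour unchanged.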
@@ -544,7 +333,8 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
 }

 at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output, at::Tensor mask, long long rows, long long cols, float scale) {
     CHECK_INPUT(output);
     CHECK_INPUT(mask);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
...
@@ -555,11 +345,16 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_grad_fp32<<<grid, block>>>(
+        fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
             (float *)d_output.data_ptr(), (float *)output.data_ptr(), (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
-    } else {
-        fastfold_softmax_scale_mask_grad_bfp16<<<grid, block>>>(
+    } else if (output.dtype() == torch::kFloat16) {
+        fastfold_softmax_scale_mask_grad<at::Half><<<grid, block>>>((at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(), (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, scale, head);
+    } else if (output.dtype() == torch::kBFloat16) {
+        fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>((at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(), (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols, scale, head);
...
@@ -570,9 +365,10 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
 ////////////////

-__global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask, float *bias, float *output, long long rows, long long cols, float scale, int head) {
+template <typename T>
+__global__ void fastfold_softmax_scale_mask_bias(T *input, T *mask, T *bias, T *output, long long rows, long long cols, float scale, int head) {
     int threadidx_x = threadIdx.x / 32;
     int threadidx_y = threadIdx.x % 32;
     long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
...
@@ -583,69 +379,7 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
     if (threadidx_y == last_y) {
         cols_this_thread = cols - cols_per_thread * last_y;
     } else if (threadidx_y > last_y) {
         cols_this_thread = 0;
     }

     float buf[32];

     int lane_id = threadidx_y;

     if (row_offset < rows) {
-        float *row_input = input + row_offset * cols;
-        float *row_output = output + row_offset * cols;
-        float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
-        float *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
-                buf[i] = -1 * 10e9;
-            } else {
-                buf[i] = row_input[lane_id * cols_per_thread + i] * scale + bias_ptr[lane_id * cols_per_thread + i];
-            }
-        }
-        float thread_max = -1 * CUDART_INF_F;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; i++) {
-            thread_max = max(thread_max, buf[i]);
-        }
-        float warp_max = WarpAllReduceMax(thread_max);
-        float thread_sum = 0.f;
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            buf[i] = __expf(buf[i] - warp_max);
-            thread_sum += buf[i];
-        }
-        float warp_sum = WarpAllReduceSum(thread_sum);
-#pragma unroll
-        for (int i = 0; i < cols_this_thread; ++i) {
-            row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
-        }
-    }
-}

-__global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::BFloat16 *mask, at::BFloat16 *bias, at::BFloat16 *output, long long rows, long long cols, float scale, int head) {
-    int threadidx_x = threadIdx.x / 32;
-    int threadidx_y = threadIdx.x % 32;
-    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
-    int cols_per_thread = (cols + 31) / 32;
-    int cols_this_thread = cols_per_thread;
-    int last_y = (cols / cols_per_thread);
-    if (threadidx_y == last_y) {
-        cols_this_thread = cols - cols_per_thread * last_y;
-    } else if (threadidx_y > last_y) {
-        cols_this_thread = 0;
-    }
...
@@ -654,23 +388,23 @@ __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::
     int lane_id = threadidx_y;

     if (row_offset < rows) {
-        at::BFloat16 *row_input = input + row_offset * cols;
-        at::BFloat16 *row_output = output + row_offset * cols;
-        at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
-        at::BFloat16 *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
+        T *row_input = input + row_offset * cols;
+        T *row_output = output + row_offset * cols;
+        T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
+        T *bias_ptr = bias + ((row_offset % (head * cols)) * cols);

#pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                 buf[i] = -1 * 10e9;
             } else {
-                buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]) * scale;
-                buf[i] += static_cast<float>(bias_ptr[lane_id * cols_per_thread + i]);
+                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) * scale;
+                buf[i] += static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
             }
         }

         float thread_max = -1 * CUDART_INF_F;

#pragma unroll
         for (int i = 0; i < cols_this_thread; i++) {
             thread_max = max(thread_max, buf[i]);
         }
...
@@ -678,17 +412,17 @@ __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::
         float warp_max = WarpAllReduceMax(thread_max);

         float thread_sum = 0.f;
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             buf[i] = __expf(buf[i] - warp_max);
             thread_sum += buf[i];
         }

         float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
         for (int i = 0; i < cols_this_thread; ++i) {
             row_output[lane_id * cols_per_thread + i] =
-                static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
+                static_cast<T>(__fdividef(buf[i], warp_sum));
         }
     }
 }
...
@@ -706,14 +440,18 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
     dim3 block(128);

     if (input.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_bias_fp32<<<grid, block>>>(
+        fastfold_softmax_scale_mask_bias<float><<<grid, block>>>(
             (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(), (float *)output.data_ptr(), rows, cols, scale, head);
-    } else {
-        fastfold_softmax_scale_mask_bias_bfp16<<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols, scale, head);
-    }
+    } else if (input.dtype() == torch::kFloat16) {
+        fastfold_softmax_scale_mask_bias<at::Half><<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), (at::Half *)bias.data_ptr(), (at::Half *)output.data_ptr(), rows, cols, scale, head);
+    } else if (input.dtype() == torch::kBFloat16) {
+        fastfold_softmax_scale_mask_bias<at::BFloat16><<<grid, block>>>((at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols, scale, head);
+    }

     return output;
...
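The scale+mask+bias variant fuses the usual attention-logit preprocessing with the softmax itself: masked positions are forced to a large negative value, the remaining logits are scaled and shifted by the bias, and the same stable softmax is then applied. A one-row CPU sketch of the fused forward (not from the commit, constants simplified):

// Sketch: per-row reference for the fused scale + mask + bias softmax above.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> fused_row(const std::vector<float> &x, const std::vector<float> &mask,
                             const std::vector<float> &bias, float scale) {
    std::vector<float> buf(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        buf[i] = (mask[i] == 0.f) ? -1e9f : x[i] * scale + bias[i];  // mask, scale, bias
    float m = -INFINITY;
    for (float v : buf) m = std::max(m, v);
    float s = 0.f;
    for (float &v : buf) { v = std::exp(v - m); s += v; }
    for (float &v : buf) v /= s;                                     // softmax
    return buf;
}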
@@ -732,11 +470,16 @@ at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tenso
     dim3 block(128);

     if (output.dtype() == torch::kFloat32) {
-        fastfold_softmax_scale_mask_grad_fp32<<<grid, block>>>(
+        fastfold_softmax_scale_mask_grad<float><<<grid, block>>>(
             (float *)d_output.data_ptr(), (float *)output.data_ptr(), (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, scale, head);
-    } else {
-        fastfold_softmax_scale_mask_grad_bfp16<<<grid, block>>>(
+    } else if (output.dtype() == torch::kFloat16) {
+        fastfold_softmax_scale_mask_grad<at::Half><<<grid, block>>>((at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(), (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, scale, head);
+    } else if (output.dtype() == torch::kBFloat16) {
+        fastfold_softmax_scale_mask_grad<at::BFloat16><<<grid, block>>>((at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(), (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols, scale, head);
...
tests/test_fastnn/test_softmax.py  0 → 100644 (view file @ a65d5009)
import torch

from fastfold.model.fastnn.kernel import softmax


def test_softmax():
    # [batch, dim]
    test_shape = [[64, 64], [64, 128], [64, 129], [64, 1024]]
    test_dtype = [torch.float32, torch.float16, torch.bfloat16]
    test_device = torch.device("cuda")
    tolerance_eps = {torch.float32: 10e-5, torch.float16: 10e-2, torch.bfloat16: 10e-2}

    for shape in test_shape:
        for dtype in test_dtype:
            sample_input = torch.rand(shape).to(device=test_device, dtype=dtype).requires_grad_(True)
            sample_input_fastnn = torch.clone(sample_input.detach()).requires_grad_(True)

            # Forward
            torch_out = torch.nn.functional.softmax(sample_input, dim=-1)
            fastnn_out = softmax(sample_input_fastnn)

            forward_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
            assert forward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"

            # Backward
            out_grad = torch.rand_like(torch_out).requires_grad_(False)

            torch_out.backward(out_grad)
            fastnn_out.backward(out_grad)

            backward_error = torch.max(torch.abs(sample_input.grad - sample_input_fastnn.grad)).cpu().item()
            assert backward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"