Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastFold
Commits
f44557ed
Commit
f44557ed
authored
Jul 25, 2022
by
shenggan
Browse files
refactor softmax kernel
parent
a65d5009
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
471 additions
and
137 deletions
+471
-137
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
...del/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
+468
-134
setup.py
setup.py
+1
-1
tests/test_fastnn/test_softmax.py
tests/test_fastnn/test_softmax.py
+2
-2
No files found.
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
View file @
f44557ed
...
...
@@ -2,6 +2,7 @@
#include <math_constants.h>
#include <torch/extension.h>
#include <cub/cub.cuh>
#include <iostream>
#include "ATen/ATen.h"
...
...
@@ -28,25 +29,65 @@ __inline__ __device__ float WarpAllReduceSum(float val) {
return
val
;
}
////////////////
inline
cudaError_t
GetNumBlocks
(
int64_t
block_size
,
int64_t
max_blocks
,
int64_t
waves
,
int
*
num_blocks
)
{
int
dev
;
{
cudaError_t
err
=
cudaGetDevice
(
&
dev
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
}
int
sm_count
;
{
cudaError_t
err
=
cudaDeviceGetAttribute
(
&
sm_count
,
cudaDevAttrMultiProcessorCount
,
dev
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
}
int
tpm
;
{
cudaError_t
err
=
cudaDeviceGetAttribute
(
&
tpm
,
cudaDevAttrMaxThreadsPerMultiProcessor
,
dev
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
}
*
num_blocks
=
std
::
max
<
int
>
(
1
,
std
::
min
<
int64_t
>
(
max_blocks
,
sm_count
*
tpm
/
block_size
*
waves
));
return
cudaSuccess
;
}
template
<
typename
T
>
struct
SumOp
{
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
MaxOp
{
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
max
(
a
,
b
);
}
};
template
<
template
<
typename
>
class
ReductionOp
,
typename
T
,
int
block_size
>
__inline__
__device__
T
BlockAllReduce
(
T
val
)
{
typedef
cub
::
BlockReduce
<
T
,
block_size
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
T
result_broadcast
;
T
result
=
BlockReduce
(
temp_storage
).
Reduce
(
val
,
ReductionOp
<
T
>
());
if
(
threadIdx
.
x
==
0
)
{
result_broadcast
=
result
;
}
__syncthreads
();
return
result_broadcast
;
}
////////////////
template
<
typename
T
,
int
cols_per_thread
>
__global__
void
fastfold_softmax
(
T
*
input
,
T
*
output
,
long
long
rows
,
long
long
cols
)
{
int
threadidx_x
=
threadIdx
.
x
/
32
;
int
threadidx_y
=
threadIdx
.
x
%
32
;
long
long
row_offset
=
(
long
long
)
blockIdx
.
x
*
4
+
threadidx_x
;
int
cols_per_thread
=
(
cols
+
31
)
/
32
;
int
cols_this_thread
=
cols_per_thread
;
int
last_y
=
(
cols
/
cols_per_thread
);
if
(
threadidx_y
==
last_y
)
{
cols_this_thread
=
cols
-
cols_per_thread
*
last_y
;
}
else
if
(
threadidx_y
>
last_y
)
{
cols_this_thread
=
0
;
}
float
buf
[
32
];
float
buf
[
cols_per_thread
];
int
lane_id
=
threadidx_y
;
...
...
@@ -57,12 +98,16 @@ __global__ void fastfold_softmax(T *input, T *output, long long rows, long long
float
thread_max
=
-
1
*
CUDART_INF_F
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
i
++
)
{
buf
[
i
]
=
static_cast
<
float
>
(
row_input
[
lane_id
*
cols_per_thread
+
i
]);
for
(
int
i
=
0
;
i
<
cols_per_thread
;
i
++
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
buf
[
i
]
=
static_cast
<
T
>
(
row_input
[
lane_id
*
cols_per_thread
+
i
]);
}
else
{
buf
[
i
]
=
-
1
*
CUDART_INF_F
;
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_
this
_thread
;
i
++
)
{
for
(
int
i
=
0
;
i
<
cols_
per
_thread
;
i
++
)
{
thread_max
=
max
(
thread_max
,
buf
[
i
]);
}
...
...
@@ -70,16 +115,47 @@ __global__ void fastfold_softmax(T *input, T *output, long long rows, long long
float
thread_sum
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_
this
_thread
;
++
i
)
{
for
(
int
i
=
0
;
i
<
cols_
per
_thread
;
++
i
)
{
buf
[
i
]
=
__expf
(
buf
[
i
]
-
warp_max
);
thread_sum
+=
buf
[
i
];
}
float
warp_sum
=
WarpAllReduceSum
(
thread_sum
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
++
i
)
{
row_output
[
lane_id
*
cols_per_thread
+
i
]
=
static_cast
<
T
>
(
__fdividef
(
buf
[
i
],
warp_sum
));
for
(
int
i
=
0
;
i
<
cols_per_thread
;
++
i
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
row_output
[
lane_id
*
cols_per_thread
+
i
]
=
static_cast
<
T
>
(
__fdividef
(
buf
[
i
],
warp_sum
));
}
}
}
}
template
<
typename
T
,
int
block_size
>
__global__
void
fastfold_softmax_sm
(
T
*
input
,
T
*
output
,
long
long
rows
,
long
long
cols
)
{
extern
__shared__
__align__
(
sizeof
(
double
))
unsigned
char
shared_buf
[];
auto
*
buf
=
reinterpret_cast
<
float
*>
(
shared_buf
);
const
int
tid
=
threadIdx
.
x
;
for
(
int64_t
row
=
blockIdx
.
x
;
row
<
rows
;
row
+=
gridDim
.
x
)
{
float
thread_max
=
-
1
*
CUDART_INF_F
;
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
buf
[
id
]
=
static_cast
<
T
>
(
input
[
row
*
cols
+
id
]);
thread_max
=
max
(
thread_max
,
buf
[
id
]);
}
const
float
row_max
=
BlockAllReduce
<
MaxOp
,
float
,
block_size
>
(
thread_max
);
float
thread_sum
=
0
;
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
buf
[
id
]
=
__expf
(
buf
[
id
]
-
row_max
);
thread_sum
+=
buf
[
id
];
}
const
float
row_sum
=
BlockAllReduce
<
SumOp
,
float
,
block_size
>
(
thread_sum
);
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
output
[
row
*
cols
+
id
]
=
static_cast
<
T
>
(
buf
[
id
]
/
row_sum
);
}
}
}
...
...
@@ -93,17 +169,83 @@ at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
int
grid
=
(
rows
+
3
)
/
4
;
dim3
block
(
128
);
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax
<
float
>
<<<
grid
,
block
>>>
((
float
*
)
input
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax
<
at
::
Half
><<<
grid
,
block
>>>
((
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax
<
at
::
BFloat16
><<<
grid
,
block
>>>
(
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
output
.
data_ptr
(),
rows
,
cols
);
if
(
cols
<=
32
)
{
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax
<
float
,
1
><<<
grid
,
block
>>>
((
float
*
)
input
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax
<
at
::
Half
,
1
><<<
grid
,
block
>>>
(
(
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax
<
at
::
BFloat16
,
1
><<<
grid
,
block
>>>
(
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
}
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax<float, col_per_thread><<<grid, block>>>( \
(float *)input.data_ptr(), (float *)output.data_ptr(), rows, cols); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax<at::Half, col_per_thread><<<grid, block>>>( \
(at::Half *)input.data_ptr(), (at::Half *)output.data_ptr(), rows, cols); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols); \
} \
}
COLS_CASE
(
2
)
COLS_CASE
(
3
)
COLS_CASE
(
4
)
COLS_CASE
(
5
)
COLS_CASE
(
6
)
COLS_CASE
(
7
)
COLS_CASE
(
8
)
COLS_CASE
(
9
)
COLS_CASE
(
10
)
COLS_CASE
(
11
)
COLS_CASE
(
12
)
COLS_CASE
(
13
)
COLS_CASE
(
14
)
COLS_CASE
(
15
)
COLS_CASE
(
16
)
COLS_CASE
(
17
)
COLS_CASE
(
18
)
COLS_CASE
(
19
)
COLS_CASE
(
20
)
COLS_CASE
(
21
)
COLS_CASE
(
22
)
COLS_CASE
(
23
)
COLS_CASE
(
24
)
COLS_CASE
(
25
)
COLS_CASE
(
26
)
COLS_CASE
(
27
)
COLS_CASE
(
28
)
COLS_CASE
(
29
)
COLS_CASE
(
30
)
COLS_CASE
(
31
)
COLS_CASE
(
32
)
#undef COLS_CASE
else
{
int
grid_dim
;
constexpr
int
waves
=
32
;
GetNumBlocks
(
128
,
rows
,
waves
,
&
grid_dim
);
dim3
block
(
128
);
const
size_t
smem
=
cols
*
sizeof
(
float
);
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_sm
<
float
,
128
><<<
grid_dim
,
block
,
smem
>>>
(
(
float
*
)
input
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_sm
<
at
::
Half
,
128
><<<
grid_dim
,
block
,
smem
>>>
(
(
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_sm
<
at
::
BFloat16
,
128
><<<
grid_dim
,
block
,
smem
>>>
(
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
}
return
output
;
}
...
...
@@ -124,8 +266,8 @@ __global__ void fastfold_softmax_grad(T *d_output, T *output, T *d_input, long l
cols_this_thread
=
0
;
}
float
y_buf
[
32
];
float
dy_buf
[
32
];
float
y_buf
[
8
];
float
dy_buf
[
8
];
int
lane_id
=
threadidx_y
;
...
...
@@ -187,61 +329,83 @@ at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long ro
////////////////
template
<
typename
T
>
template
<
typename
T
,
int
cols_per_thread
>
__global__
void
fastfold_softmax_scale_mask
(
T
*
input
,
T
*
mask
,
T
*
output
,
long
long
rows
,
long
long
cols
,
float
scale
,
int
head
)
{
int
threadidx_x
=
threadIdx
.
x
/
32
;
int
threadidx_y
=
threadIdx
.
x
%
32
;
long
long
row_offset
=
(
long
long
)
blockIdx
.
x
*
4
+
threadidx_x
;
int
cols_per_thread
=
(
cols
+
31
)
/
32
;
int
cols_this_thread
=
cols_per_thread
;
int
last_y
=
(
cols
/
cols_per_thread
)
;
float
buf
[
cols_per_thread
]
;
if
(
threadidx_y
==
last_y
)
{
cols_this_thread
=
cols
-
cols_per_thread
*
last_y
;
}
else
if
(
threadidx_y
>
last_y
)
{
cols_this_thread
=
0
;
int
lane_id
=
threadidx_y
;
T
*
row_input
=
input
+
row_offset
*
cols
;
T
*
row_output
=
output
+
row_offset
*
cols
;
T
*
mask_ptr
=
mask
+
((
row_offset
/
(
head
*
cols
))
*
cols
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
i
++
)
{
if
(
mask_ptr
[
lane_id
*
cols_per_thread
+
i
]
==
0
)
{
buf
[
i
]
=
-
1
*
1e9
;
}
else
{
buf
[
i
]
=
static_cast
<
T
>
(
row_input
[
lane_id
*
cols_per_thread
+
i
])
*
scale
;
}
}
float
buf
[
32
];
float
thread_max
=
-
1
*
CUDART_INF_F
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
i
++
)
{
thread_max
=
max
(
thread_max
,
buf
[
i
]);
}
int
lane_id
=
thread
idx_y
;
float
warp_max
=
WarpAllReduceMax
(
thread
_max
)
;
if
(
row_offset
<
rows
)
{
T
*
row_input
=
input
+
row_offset
*
cols
;
T
*
row_output
=
output
+
row_offset
*
cols
;
T
*
mask_ptr
=
mask
+
((
row_offset
/
(
head
*
cols
))
*
cols
);
float
thread_sum
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
++
i
)
{
buf
[
i
]
=
__expf
(
buf
[
i
]
-
warp_max
);
thread_sum
+=
buf
[
i
];
}
float
warp_sum
=
WarpAllReduceSum
(
thread_sum
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
i
++
)
{
if
(
mask_ptr
[
lane_id
*
cols_per_thread
+
i
]
==
0
)
{
buf
[
i
]
=
-
1
*
1e9
;
}
else
{
buf
[
i
]
=
static_cast
<
T
>
(
row_input
[
lane_id
*
cols_per_thread
+
i
])
*
scale
;
}
}
for
(
int
i
=
0
;
i
<
cols_per_thread
;
++
i
)
{
row_output
[
lane_id
*
cols_per_thread
+
i
]
=
static_cast
<
T
>
(
__fdividef
(
buf
[
i
],
warp_sum
));
}
}
template
<
typename
T
,
int
block_size
>
__global__
void
fastfold_softmax_scale_mask_sm
(
T
*
input
,
T
*
mask
,
T
*
output
,
long
long
rows
,
long
long
cols
,
float
scale
,
int
head
)
{
extern
__shared__
__align__
(
sizeof
(
double
))
unsigned
char
shared_buf
[];
auto
*
buf
=
reinterpret_cast
<
float
*>
(
shared_buf
);
const
int
tid
=
threadIdx
.
x
;
for
(
int64_t
row
=
blockIdx
.
x
;
row
<
rows
;
row
+=
gridDim
.
x
)
{
T
*
mask_ptr
=
mask
+
((
row
/
(
head
*
cols
))
*
cols
);
float
thread_max
=
-
1
*
CUDART_INF_F
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
i
++
)
{
thread_max
=
max
(
thread_max
,
buf
[
i
]);
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
if
(
mask_ptr
[
id
]
==
0
)
{
buf
[
id
]
=
-
1
*
1e9
;
}
else
{
buf
[
id
]
=
input
[
row
*
cols
+
id
]
*
scale
;
}
thread_max
=
max
(
thread_max
,
buf
[
id
]);
}
float
warp
_max
=
Warp
AllReduceMax
(
thread_max
);
const
float
row
_max
=
Block
AllReduce
<
Max
Op
,
float
,
block_size
>
(
thread_max
);
float
thread_sum
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
++
i
)
{
buf
[
i
]
=
__expf
(
buf
[
i
]
-
warp_max
);
thread_sum
+=
buf
[
i
];
float
thread_sum
=
0
;
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
buf
[
id
]
=
__expf
(
buf
[
id
]
-
row_max
);
thread_sum
+=
buf
[
id
];
}
float
warp_sum
=
WarpAllReduceSum
(
thread_sum
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
++
i
)
{
row_output
[
lane_id
*
cols_per_thread
+
i
]
=
static_cast
<
T
>
(
__fdividef
(
buf
[
i
],
warp_sum
));
const
float
row_sum
=
BlockAllReduce
<
SumOp
,
float
,
block_size
>
(
thread_sum
);
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
output
[
row
*
cols
+
id
]
=
buf
[
id
]
/
row_sum
;
}
}
}
...
...
@@ -252,26 +416,97 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
CHECK_INPUT
(
mask
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
int
head
=
input
.
sizes
()[
2
];
at
::
Tensor
output
=
at
::
empty_like
(
input
);
//
at::Tensor output = at::empty_like(input);
int
grid
=
(
rows
+
3
)
/
4
;
dim3
block
(
128
);
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_scale_mask
<
float
>
<<<
grid
,
block
>>>
((
float
*
)
input
.
data_ptr
(),
(
float
*
)
mask
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_scale_mask
<
at
::
Half
>
<<<
grid
,
block
>>>
((
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
mask
.
data_ptr
(),
(
at
::
Half
*
)
output
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_scale_mask
<
at
::
BFloat16
>
<<<
grid
,
block
>>>
((
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
mask
.
data_ptr
(),
(
at
::
BFloat16
*
)
output
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
if
(
cols
<=
32
)
{
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_scale_mask
<
float
,
1
>
<<<
grid
,
block
>>>
((
float
*
)
input
.
data_ptr
(),
(
float
*
)
mask
.
data_ptr
(),
(
float
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_scale_mask
<
at
::
Half
,
1
>
<<<
grid
,
block
>>>
((
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
mask
.
data_ptr
(),
(
at
::
Half
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_scale_mask
<
at
::
BFloat16
,
1
>
<<<
grid
,
block
>>>
((
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
mask
.
data_ptr
(),
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
}
return
output
;
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax_scale_mask<float, col_per_thread> \
<<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(), \
(float *)input.data_ptr(), rows, cols, scale, head); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax_scale_mask<at::Half, col_per_thread> \
<<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), \
(at::Half *)input.data_ptr(), rows, cols, scale, head); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax_scale_mask<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), \
(at::BFloat16 *)input.data_ptr(), rows, cols, scale, head); \
} \
}
COLS_CASE
(
2
)
COLS_CASE
(
3
)
COLS_CASE
(
4
)
COLS_CASE
(
5
)
COLS_CASE
(
6
)
COLS_CASE
(
7
)
COLS_CASE
(
8
)
COLS_CASE
(
9
)
COLS_CASE
(
10
)
COLS_CASE
(
11
)
COLS_CASE
(
12
)
COLS_CASE
(
13
)
COLS_CASE
(
14
)
COLS_CASE
(
15
)
COLS_CASE
(
16
)
COLS_CASE
(
17
)
COLS_CASE
(
18
)
COLS_CASE
(
19
)
COLS_CASE
(
20
)
COLS_CASE
(
21
)
COLS_CASE
(
22
)
COLS_CASE
(
23
)
COLS_CASE
(
24
)
COLS_CASE
(
25
)
COLS_CASE
(
26
)
COLS_CASE
(
27
)
COLS_CASE
(
28
)
COLS_CASE
(
29
)
COLS_CASE
(
30
)
COLS_CASE
(
31
)
COLS_CASE
(
32
)
#undef COLS_CASE
else
{
int
grid_dim
;
constexpr
int
waves
=
32
;
GetNumBlocks
(
128
,
rows
,
waves
,
&
grid_dim
);
dim3
block
(
128
);
const
size_t
smem
=
cols
*
sizeof
(
float
);
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_scale_mask_sm
<
float
,
128
>
<<<
grid
,
block
,
smem
>>>
((
float
*
)
input
.
data_ptr
(),
(
float
*
)
mask
.
data_ptr
(),
(
float
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_scale_mask_sm
<
at
::
Half
,
128
>
<<<
grid
,
block
,
smem
>>>
((
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
mask
.
data_ptr
(),
(
at
::
Half
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_scale_mask_sm
<
at
::
BFloat16
,
128
><<<
grid
,
block
,
smem
>>>
(
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
mask
.
data_ptr
(),
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
}
return
input
;
}
template
<
typename
T
>
...
...
@@ -292,8 +527,8 @@ __global__ void fastfold_softmax_scale_mask_grad(T *d_output, T *output, T *d_in
cols_this_thread
=
0
;
}
float
y_buf
[
32
];
float
dy_buf
[
32
];
float
y_buf
[
8
];
float
dy_buf
[
8
];
int
lane_id
=
threadidx_y
;
...
...
@@ -365,64 +600,88 @@ at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor out
////////////////
template
<
typename
T
>
template
<
typename
T
,
int
cols_per_thread
>
__global__
void
fastfold_softmax_scale_mask_bias
(
T
*
input
,
T
*
mask
,
T
*
bias
,
T
*
output
,
long
long
rows
,
long
long
cols
,
float
scale
,
int
head
)
{
int
threadidx_x
=
threadIdx
.
x
/
32
;
int
threadidx_y
=
threadIdx
.
x
%
32
;
long
long
row_offset
=
(
long
long
)
blockIdx
.
x
*
4
+
threadidx_x
;
int
cols_per_thread
=
(
cols
+
31
)
/
32
;
int
cols_this_thread
=
cols_per_thread
;
int
last_y
=
(
cols
/
cols_per_thread
)
;
float
buf
[
cols_per_thread
]
;
if
(
threadidx_y
==
last_y
)
{
cols_this_thread
=
cols
-
cols_per_thread
*
last_y
;
}
else
if
(
threadidx_y
>
last_y
)
{
cols_this_thread
=
0
;
int
lane_id
=
threadidx_y
;
T
*
row_input
=
input
+
row_offset
*
cols
;
T
*
row_output
=
output
+
row_offset
*
cols
;
T
*
mask_ptr
=
mask
+
((
row_offset
/
(
head
*
cols
))
*
cols
);
T
*
bias_ptr
=
bias
+
((
row_offset
%
(
head
*
cols
))
*
cols
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
i
++
)
{
if
(
mask_ptr
[
lane_id
*
cols_per_thread
+
i
]
==
0
)
{
buf
[
i
]
=
-
1
*
10e9
;
}
else
{
buf
[
i
]
=
static_cast
<
T
>
(
row_input
[
lane_id
*
cols_per_thread
+
i
])
*
scale
;
buf
[
i
]
+=
static_cast
<
T
>
(
bias_ptr
[
lane_id
*
cols_per_thread
+
i
]);
}
}
float
buf
[
32
];
float
thread_max
=
-
1
*
CUDART_INF_F
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
i
++
)
{
thread_max
=
max
(
thread_max
,
buf
[
i
]);
}
int
lane_id
=
thread
idx_y
;
float
warp_max
=
WarpAllReduceMax
(
thread
_max
)
;
if
(
row_offset
<
rows
)
{
T
*
row_input
=
input
+
row_offset
*
cols
;
T
*
row_output
=
output
+
row_offset
*
cols
;
T
*
mask_ptr
=
mask
+
((
row_offset
/
(
head
*
cols
))
*
cols
);
T
*
bias_ptr
=
bias
+
((
row_offset
%
(
head
*
cols
))
*
cols
);
float
thread_sum
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
++
i
)
{
buf
[
i
]
=
__expf
(
buf
[
i
]
-
warp_max
);
thread_sum
+=
buf
[
i
];
}
float
warp_sum
=
WarpAllReduceSum
(
thread_sum
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
i
++
)
{
if
(
mask_ptr
[
lane_id
*
cols_per_thread
+
i
]
==
0
)
{
buf
[
i
]
=
-
1
*
10e9
;
}
else
{
buf
[
i
]
=
static_cast
<
T
>
(
row_input
[
lane_id
*
cols_per_thread
+
i
])
*
scale
;
buf
[
i
]
+=
static_cast
<
T
>
(
bias_ptr
[
lane_id
*
cols_per_thread
+
i
]);
}
}
for
(
int
i
=
0
;
i
<
cols_per_thread
;
++
i
)
{
row_output
[
lane_id
*
cols_per_thread
+
i
]
=
static_cast
<
T
>
(
__fdividef
(
buf
[
i
],
warp_sum
));
}
}
template
<
typename
T
,
int
block_size
>
__global__
void
fastfold_softmax_scale_mask_bias_sm
(
T
*
input
,
T
*
mask
,
T
*
bias
,
T
*
output
,
long
long
rows
,
long
long
cols
,
float
scale
,
int
head
)
{
extern
__shared__
__align__
(
sizeof
(
double
))
unsigned
char
shared_buf
[];
auto
*
buf
=
reinterpret_cast
<
float
*>
(
shared_buf
);
const
int
tid
=
threadIdx
.
x
;
for
(
int64_t
row
=
blockIdx
.
x
;
row
<
rows
;
row
+=
gridDim
.
x
)
{
T
*
mask_ptr
=
mask
+
((
row
/
(
head
*
cols
))
*
cols
);
T
*
bias_ptr
=
bias
+
((
row
%
(
head
*
cols
))
*
cols
);
float
thread_max
=
-
1
*
CUDART_INF_F
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
i
++
)
{
thread_max
=
max
(
thread_max
,
buf
[
i
]);
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
if
(
mask_ptr
[
id
]
==
0
)
{
buf
[
id
]
=
-
1
*
1e9
;
}
else
{
buf
[
id
]
=
input
[
row
*
cols
+
id
]
*
scale
+
bias_ptr
[
id
];
}
thread_max
=
max
(
thread_max
,
buf
[
id
]);
}
float
warp
_max
=
Warp
AllReduceMax
(
thread_max
);
const
float
row
_max
=
Block
AllReduce
<
Max
Op
,
float
,
block_size
>
(
thread_max
);
float
thread_sum
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
++
i
)
{
buf
[
i
]
=
__expf
(
buf
[
i
]
-
warp_max
);
thread_sum
+=
buf
[
i
];
float
thread_sum
=
0
;
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
buf
[
id
]
=
__expf
(
buf
[
id
]
-
row_max
);
thread_sum
+=
buf
[
id
];
}
float
warp_sum
=
WarpAllReduceSum
(
thread_sum
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
++
i
)
{
row_output
[
lane_id
*
cols_per_thread
+
i
]
=
static_cast
<
T
>
(
__fdividef
(
buf
[
i
],
warp_sum
));
const
float
row_sum
=
BlockAllReduce
<
SumOp
,
float
,
block_size
>
(
thread_sum
);
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
output
[
row
*
cols
+
id
]
=
buf
[
id
]
/
row_sum
;
}
}
}
...
...
@@ -434,27 +693,102 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
CHECK_INPUT
(
bias
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
int
head
=
input
.
sizes
()[
2
];
at
::
Tensor
output
=
at
::
empty_like
(
input
);
//
at::Tensor output = at::empty_like(input);
int
grid
=
(
rows
+
3
)
/
4
;
dim3
block
(
128
);
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_scale_mask_bias
<
float
><<<
grid
,
block
>>>
(
(
float
*
)
input
.
data_ptr
(),
(
float
*
)
mask
.
data_ptr
(),
(
float
*
)
bias
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_scale_mask_bias
<
at
::
Half
><<<
grid
,
block
>>>
(
(
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
mask
.
data_ptr
(),
(
at
::
Half
*
)
bias
.
data_ptr
(),
(
at
::
Half
*
)
output
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_scale_mask_bias
<
at
::
BFloat16
>
<<<
grid
,
block
>>>
((
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
mask
.
data_ptr
(),
(
at
::
BFloat16
*
)
bias
.
data_ptr
(),
(
at
::
BFloat16
*
)
output
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
if
(
cols
<=
32
)
{
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_scale_mask_bias
<
float
,
1
><<<
grid
,
block
>>>
(
(
float
*
)
input
.
data_ptr
(),
(
float
*
)
mask
.
data_ptr
(),
(
float
*
)
bias
.
data_ptr
(),
(
float
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_scale_mask_bias
<
at
::
Half
,
1
><<<
grid
,
block
>>>
(
(
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
mask
.
data_ptr
(),
(
at
::
Half
*
)
bias
.
data_ptr
(),
(
at
::
Half
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_scale_mask_bias
<
at
::
BFloat16
,
1
>
<<<
grid
,
block
>>>
((
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
mask
.
data_ptr
(),
(
at
::
BFloat16
*
)
bias
.
data_ptr
(),
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
}
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax_scale_mask_bias<float, col_per_thread><<<grid, block>>>( \
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(), \
(float *)input.data_ptr(), rows, cols, scale, head); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax_scale_mask_bias<at::Half, col_per_thread> \
<<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), \
(at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, \
cols, scale, head); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax_scale_mask_bias<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), \
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols, \
scale, head); \
} \
}
COLS_CASE
(
2
)
COLS_CASE
(
3
)
COLS_CASE
(
4
)
COLS_CASE
(
5
)
COLS_CASE
(
6
)
COLS_CASE
(
7
)
COLS_CASE
(
8
)
COLS_CASE
(
9
)
COLS_CASE
(
10
)
COLS_CASE
(
11
)
COLS_CASE
(
12
)
COLS_CASE
(
13
)
COLS_CASE
(
14
)
COLS_CASE
(
15
)
COLS_CASE
(
16
)
COLS_CASE
(
17
)
COLS_CASE
(
18
)
COLS_CASE
(
19
)
COLS_CASE
(
20
)
COLS_CASE
(
21
)
COLS_CASE
(
22
)
COLS_CASE
(
23
)
COLS_CASE
(
24
)
COLS_CASE
(
25
)
COLS_CASE
(
26
)
COLS_CASE
(
27
)
COLS_CASE
(
28
)
COLS_CASE
(
29
)
COLS_CASE
(
30
)
COLS_CASE
(
31
)
COLS_CASE
(
32
)
#undef COLS_CASE
else
{
int
grid_dim
;
constexpr
int
waves
=
32
;
GetNumBlocks
(
128
,
rows
,
waves
,
&
grid_dim
);
dim3
block
(
128
);
const
size_t
smem
=
cols
*
sizeof
(
float
);
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_scale_mask_bias_sm
<
float
,
128
><<<
grid
,
block
,
smem
>>>
(
(
float
*
)
input
.
data_ptr
(),
(
float
*
)
mask
.
data_ptr
(),
(
float
*
)
bias
.
data_ptr
(),
(
float
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_scale_mask_bias_sm
<
at
::
Half
,
128
><<<
grid
,
block
,
smem
>>>
(
(
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
mask
.
data_ptr
(),
(
at
::
Half
*
)
bias
.
data_ptr
(),
(
at
::
Half
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_scale_mask_bias_sm
<
at
::
BFloat16
,
128
><<<
grid
,
block
,
smem
>>>
(
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
mask
.
data_ptr
(),
(
at
::
BFloat16
*
)
bias
.
data_ptr
(),
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
rows
,
cols
,
scale
,
head
);
}
}
return
out
put
;
return
in
put
;
}
at
::
Tensor
fused_scale_mask_bias_softmax_backward
(
at
::
Tensor
d_output
,
at
::
Tensor
output
,
...
...
setup.py
View file @
f44557ed
...
...
@@ -87,7 +87,7 @@ if CUDA_HOME is None:
"Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc."
)
else
:
check_cuda_torch_binary_vs_bare_metal
(
CUDA_HOME
)
#
check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)
def
cuda_ext_helper
(
name
,
sources
,
extra_cuda_flags
):
return
CUDAExtension
(
...
...
tests/test_fastnn/test_softmax.py
View file @
f44557ed
...
...
@@ -5,11 +5,11 @@ from fastfold.model.fastnn.kernel import softmax
def
test_softmax
():
# [batch, dim]
test_shape
=
[[
64
,
64
],
[
64
,
128
],
[
64
,
129
],
[
64
,
1024
]]
test_shape
=
[[
64
,
64
],
[
64
,
128
],
[
64
,
129
],
[
64
,
2000
]]
test_dtype
=
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
test_device
=
torch
.
device
(
"cuda"
)
tolerance_eps
=
{
torch
.
float32
:
10e-
5
,
torch
.
float16
:
10e-2
,
torch
.
bfloat16
:
10e-2
}
tolerance_eps
=
{
torch
.
float32
:
10e-
4
,
torch
.
float16
:
10e-2
,
torch
.
bfloat16
:
10e-2
}
for
shape
in
test_shape
:
for
dtype
in
test_dtype
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment