OpenDAS / Megatron-LM · Commits · d73def35

Commit d73def35, authored Sep 18, 2020 by root
scaling-generalized masking-softmax fusion kernel
Parent: 19301985

Showing 8 changed files with 723 additions and 37 deletions (+723, -37)
megatron/arguments.py                                    +9    -1
megatron/fused_kernels/__init__.py                       +35   -10
megatron/fused_kernels/scaled_masked_softmax.cpp         +74   -0
megatron/fused_kernels/scaled_masked_softmax.h           +452  -0
megatron/fused_kernels/scaled_masked_softmax_cuda.cu     +102  -0
megatron/model/bert_model.py                             +6    -15
megatron/model/fused_softmax.py                          +44   -11
megatron/model/transformer.py                            +1    -0
megatron/arguments.py

@@ -122,6 +122,10 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.scaled_upper_triang_masked_softmax_fusion:
         fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
 
+    # load scaled_masked_softmax_fusion kernel
+    if args.scaled_masked_softmax_fusion:
+        fused_kernels.load_scaled_masked_softmax_fusion_kernel()
+
     _print_args(args)
 
     return args
@@ -228,7 +232,11 @@ def _add_training_args(parser):
     group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
                        action='store_true',
                        help='Enable fusion of query_key_value_scaling '
-                       'time (upper diagonal) masking, softmax.')
+                       'time (upper diagonal) masking and softmax.')
+    group.add_argument('--scaled-masked-softmax-fusion',
+                       action='store_true',
+                       help='Enable fusion of query_key_value_scaling '
+                       'general masking and softmax.')
     group.add_argument('--bias-gelu-fusion', action='store_true',
                        help='Enable bias and gelu fusion.')
     group.add_argument('--bias-dropout-fusion', action='store_true',
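For orientation, here is a minimal, runnable sketch (not part of the commit) of the flag-to-kernel-loader pattern these two hunks extend. The actual loader call is replaced by a print, since the real one JIT-compiles the CUDA extension.

import argparse

def parse_args_sketch():
    parser = argparse.ArgumentParser(description='Megatron argument sketch')
    group = parser.add_argument_group(title='training')
    group.add_argument('--scaled-masked-softmax-fusion', action='store_true',
                       help='Enable fusion of query_key_value scaling, '
                            'general masking and softmax.')
    args = parser.parse_args()

    if args.scaled_masked_softmax_fusion:
        # The real code calls
        # fused_kernels.load_scaled_masked_softmax_fusion_kernel() here,
        # JIT-building the CUDA extension before training starts.
        print('would JIT-build scaled_masked_softmax_cuda')
    return args

if __name__ == '__main__':
    parse_args_sketch()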
megatron/fused_kernels/__init__.py

@@ -17,9 +17,7 @@ import pathlib
 import subprocess
 
 from torch.utils import cpp_extension
 
-def load_scaled_upper_triang_masked_softmax_fusion_kernel():
-
-    def get_cuda_bare_metal_version(cuda_dir):
+def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                          universal_newlines=True)
     output = raw_output.split()
@@ -30,6 +28,9 @@ def load_scaled_upper_triang_masked_softmax_fusion_kernel():
     return raw_output, bare_metal_major, bare_metal_minor
 
 
+def load_scaled_upper_triang_masked_softmax_fusion_kernel():
+
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
     _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
@@ -51,3 +52,27 @@ def load_scaled_upper_triang_masked_softmax_fusion_kernel():
                            '--expt-extended-lambda',
                            '--use_fast_math'] + cc_flag,
         verbose=True)
+
+
+def load_scaled_masked_softmax_fusion_kernel():
+
+    # Check, if CUDA11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+    if int(bare_metal_major) >= 11:
+        cc_flag.append('-gencode')
+        cc_flag.append('arch=compute_80,code=sm_80')
+
+    srcpath = pathlib.Path(__file__).parent.absolute()
+    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
+        name='scaled_masked_softmax_cuda',
+        sources=[srcpath / 'scaled_masked_softmax.cpp',
+                 srcpath / 'scaled_masked_softmax_cuda.cu'],
+        extra_cflags=['-O3',],
+        extra_cuda_cflags=['-O3',
+                           '-gencode', 'arch=compute_70,code=sm_70',
+                           '-U__CUDA_NO_HALF_OPERATORS__',
+                           '-U__CUDA_NO_HALF_CONVERSIONS__',
+                           '--expt-relaxed-constexpr',
+                           '--expt-extended-lambda',
+                           '--use_fast_math'] + cc_flag,
+        verbose=True)
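The new loader mirrors the existing one: torch.utils.cpp_extension.load JIT-compiles the listed sources into an extension module registered under name='scaled_masked_softmax_cuda'. The return value does not need to be kept, because later code (see megatron/model/fused_softmax.py below) imports the built module by that name. A hedged sketch of that consumption pattern, assuming a CUDA toolchain and a Megatron checkout:

# Sketch (not part of the commit): build the kernel once at startup,
# then import the resulting extension module by its registered name.
from megatron import fused_kernels

fused_kernels.load_scaled_masked_softmax_fusion_kernel()

import scaled_masked_softmax_cuda  # available once the JIT build has finished
print(scaled_masked_softmax_cuda.forward.__doc__)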
megatron/fused_kernels/scaled_masked_softmax.cpp (new file, mode 100644)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
             "Only HALF is supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
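The binding above only checks shapes and dtypes and forwards to the CUDA implementation. As a plain-PyTorch reference for the semantics the fused forward is meant to provide (a sketch, not part of the commit): scale the scores, push positions where the mask equals 1 to -10000, and take a softmax over the last dimension, accumulating in fp32 and returning fp16.

import torch

def scaled_masked_softmax_reference(inputs, mask, scale):
    # inputs: [b, np, sq, sk] half-precision attention scores
    # mask:   [b or 1, 1, sq, sk], nonzero entries are masked out
    # scale:  float applied to the scores before the softmax
    scores = inputs.float() * scale
    scores = scores.masked_fill(mask.bool(), -10000.0)
    return torch.nn.functional.softmax(scores, dim=-1).half()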
megatron/fused_kernels/scaled_masked_softmax.h (new file, mode 100644)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
namespace {

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template<typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a + b;
    }
};

template<typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a < b ? b : a;
    }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask,
                                                  int width = warpSize,
                                                  unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
          template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}

/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Explicit masking
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const acc_t scale,
    int batch_size,
    int stride,
    int element_count,
    int pad_batches)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) { // bert style
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else { // gpt2 style
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

    // batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * stride + local_idx;
    dst += first_batch * stride + local_idx;
    mask += pad_first_batch * stride + local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            int itr_idx = i * element_count + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                if (mask[itr_idx] != 1) {
                    elements[i][it] = (acc_t)src[itr_idx] * scale;
                } else {
                    elements[i][it] = -10000.0;
                }
            } else {
                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                dst[i * element_count + it * WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
            } else {
                break;
            }
        }
    }
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * stride + local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
            } else {
                output_reg[i][it] = acc_t(0);
            }
        }

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                grad_reg[i][it] = (acc_t)grad[i * element_count + it * WARP_SIZE] * output_reg[i][it];
            } else {
                grad_reg[i][it] = acc_t(0);
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                gradInput[i * element_count + it * WARP_SIZE] =
                    (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
            }
        }
    }
}

} // end of anonymous namespace

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = batches * attn_heads * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(seq_len % batches_per_block == 0);
        dim3 blocks(seq_len / batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
                break;
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = batches * attn_heads * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = batch_count / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}
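The backward kernel computes, row by row, dX = scale * (Y * dY - Y * sum(Y * dY)), i.e. the standard softmax backward chained with the input scaling (the mask needs no gradient, since masked positions were set to a constant). A small PyTorch sketch (not part of the commit) that checks this formula against autograd:

import torch

def fused_softmax_backward_reference(grad_output, softmax_results, scale):
    # Mirrors scaled_masked_softmax_warp_backward:
    #   grad_reg = dY * Y, then dX = scale * (grad_reg - Y * sum(grad_reg))
    y = softmax_results.float()
    dy = grad_output.float()
    inner = (y * dy).sum(dim=-1, keepdim=True)
    return scale * (y * dy - y * inner)

if __name__ == '__main__':
    torch.manual_seed(0)
    scale = 0.125
    x = torch.randn(2, 4, 8, 8, requires_grad=True)
    y = torch.softmax(x * scale, dim=-1)
    dy = torch.randn_like(y)
    y.backward(dy)
    ref = fused_softmax_backward_reference(dy, y.detach(), scale)
    print(torch.allclose(x.grad, ref, atol=1e-6))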
megatron/fused_kernels/scaled_masked_softmax_cuda.cu (new file, mode 100644)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor)
{
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);
  const int seq_len = input.size(2);
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);
  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
  TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
  TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* mask_ptr = static_cast<void*>(mask.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  dispatch_scaled_masked_softmax_forward<half, half, float>(
      reinterpret_cast<half*>(softmax_results_ptr),
      reinterpret_cast<const half*>(input_ptr),
      reinterpret_cast<const uint8_t*>(mask_ptr),
      scale_factor,
      seq_len,
      seq_len,
      batches,
      attn_heads,
      pad_batches);

  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int seq_len = output_grads.size(2);
  TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad
  dispatch_scaled_masked_softmax_backward<half, half, float>(
      reinterpret_cast<half*>(output_grads_ptr),
      reinterpret_cast<half*>(output_grads_ptr),
      reinterpret_cast<half const*>(softmax_results.data_ptr()),
      scale_factor,
      seq_len,
      seq_len,
      batches,
      attn_heads);

  // backward pass is completely in-place
  return output_grads;
}
}
}
}
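Putting the shape checks above together, a hedged sketch (not part of the commit) of calling the built extension directly; it assumes the kernel has already been JIT-compiled via fused_kernels and that a CUDA device is available:

import torch
import scaled_masked_softmax_cuda  # built by load_scaled_masked_softmax_fusion_kernel()

b, np_, sq = 4, 16, 128
scores = torch.randn(b, np_, sq, sq, dtype=torch.half, device='cuda')
# mask entries equal to 1 are masked out; its batch dimension must be 1 or b
mask = torch.zeros(b, 1, sq, sq, dtype=torch.uint8, device='cuda')
scale = 0.125

probs = scaled_masked_softmax_cuda.forward(scores, mask, scale)     # half, [b, np, sq, sq]
grads = torch.randn_like(probs)
dscores = scaled_masked_softmax_cuda.backward(grads, probs, scale)  # in-place on grads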
megatron/model/bert_model.py

@@ -28,13 +28,11 @@ from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from megatron.module import MegatronModule
 
 
 def bert_attention_mask_func(attention_scores, attention_mask):
-    attention_scores = attention_scores + attention_mask
+    attention_scores.masked_fill_(attention_mask, -10000.0)
     return attention_scores
 
 
-def bert_extended_attention_mask(attention_mask, dtype):
+def bert_extended_attention_mask(attention_mask):
     # We create a 3D attention mask from a 2D tensor mask.
     # [b, 1, s]
     attention_mask_b1s = attention_mask.unsqueeze(1)
@@ -44,17 +42,11 @@ def bert_extended_attention_mask(attention_mask, dtype):
     attention_mask_bss = attention_mask_b1s * attention_mask_bs1
     # [b, 1, s, s]
     extended_attention_mask = attention_mask_bss.unsqueeze(1)
 
-    # Since attention_mask is 1.0 for positions we want to attend and 0.0
-    # for masked positions, this operation will create a tensor which is
-    # 0.0 for positions we want to attend and -10000.0 for masked positions.
-    # Since we are adding it to the raw scores before the softmax, this is
-    # effectively the same as removing these entirely.
-
-    # fp16 compatibility
-    extended_attention_mask = extended_attention_mask.to(dtype=dtype)
-    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+    # Convert attention mask to binary:
+    extended_attention_mask = (extended_attention_mask < 0.5)
+
     return extended_attention_mask
 
 
 def bert_position_ids(token_ids):
     # Create position ids
@@ -143,8 +135,7 @@ class BertModel(MegatronModule):
     def forward(self, input_ids, attention_mask,
                 tokentype_ids=None, lm_labels=None):
 
-        extended_attention_mask = bert_extended_attention_mask(
-            attention_mask, next(self.language_model.parameters()).dtype)
+        extended_attention_mask = bert_extended_attention_mask(attention_mask)
         position_ids = bert_position_ids(input_ids)
 
         if self.add_binary_head:
megatron/model/fused_softmax.py

@@ -43,6 +43,34 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
                                                 scale_t[0])
         return input_grads, None
 
 
+class ScaledMaskedSoftmax(torch.autograd.Function) :
+    """
+    Fused operation which performs following three operations in sequence
+    1. Scale the tensor.
+    2. Apply the mask.
+    3. Perform softmax.
+    """
+    @staticmethod
+    def forward(ctx, inputs, mask, scale):
+        import scaled_masked_softmax_cuda
+        scale_t = torch.tensor([scale])
+
+        softmax_results = \
+            scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
+        ctx.save_for_backward(softmax_results, scale_t)
+        return softmax_results
+
+    @staticmethod
+    def backward(ctx, output_grads):
+        import scaled_masked_softmax_cuda
+        softmax_results, scale_t = ctx.saved_tensors
+
+        input_grads = \
+            scaled_masked_softmax_cuda.backward(output_grads,
+                                                softmax_results,
+                                                scale_t[0])
+        return input_grads, None, None
+
+
 class FusedScaleMaskSoftmax(torch.nn.Module):
     """
     fused operation: scaling + mask + softmax
@@ -55,11 +83,12 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         scale: scaling factor used in input tensor scaling.
     """
-    def __init__(self, input_in_fp16, upper_triang_mask,
-                 mask_func, softmax_in_fp32, scale):
+    def __init__(self, input_in_fp16, upper_triang_mask_fusion,
+                 general_mask_fusion, mask_func, softmax_in_fp32, scale):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
-        self.upper_triang_mask = upper_triang_mask
+        self.upper_triang_mask_fusion = upper_triang_mask_fusion
+        self.general_mask_fusion = general_mask_fusion
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale
@@ -72,20 +101,24 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         data_size = input.size()
         assert input.dim() == 4
 
-        # invoke custom kernel for implicit uuper triangular masking
-        if self.input_in_fp16 and self.upper_triang_mask and \
-           data_size[-1] <= 2048 and \
-           input.size()[2] == input.size()[3]:
-            input = input.view(-1, data_size[2], data_size[3])
+        # invoke custom kernel
+        if self.input_in_fp16 and data_size[-1] <= 2048 and \
+           (self.upper_triang_mask_fusion or self.general_mask_fusion) and \
+           input.size()[2] == input.size()[3]:
             scale = self.scale if self.scale is not None else 1.0
-            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
-            probs = probs.view(*data_size)
+            if self.upper_triang_mask_fusion:
+                input = input.view(-1, data_size[2], data_size[3])
+                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
+                probs = probs.view(*data_size)
+            else:
+                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
         else:
             if self.input_in_fp16 and self.softmax_in_fp32:
                 input = input.float()
 
-            mask_output = self.mask_func(input, mask)
             if self.scale is not None:
-                mask_output = mask_output * self.scale
+                input = input * self.scale
+            mask_output = self.mask_func(input, mask)
             probs = torch.nn.Softmax(dim=-1)(mask_output)
 
             if self.input_in_fp16 and self.softmax_in_fp32:
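A usage sketch of the extended module (not part of the commit), with the constructor arguments introduced in this hunk; it assumes the fused kernels were loaded at startup, fp16 inputs on a CUDA device, the forward signature forward(input, mask), and the post-commit mask convention (nonzero means masked):

import torch
from megatron.model.fused_softmax import FusedScaleMaskSoftmax

def attention_mask_func(attention_scores, attention_mask):
    # same convention as bert_attention_mask_func after this commit
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores

scale_mask_softmax = FusedScaleMaskSoftmax(
    True,                  # input_in_fp16
    False,                 # upper_triang_mask_fusion (causal / GPT-2 path)
    True,                  # general_mask_fusion (the new generalized kernel)
    attention_mask_func,
    False,                 # softmax_in_fp32
    None)                  # scale

scores = torch.randn(2, 16, 128, 128, dtype=torch.half, device='cuda')
mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device='cuda')
probs = scale_mask_softmax(scores, mask)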
megatron/model/transformer.py

@@ -153,6 +153,7 @@ class ParallelSelfAttention(MegatronModule):
         self.scale_mask_softmax = FusedScaleMaskSoftmax(
             self.fp16,
             args.scaled_upper_triang_masked_softmax_fusion,
+            args.scaled_masked_softmax_fusion,
             self.attention_mask_func,
             self.attention_softmax_in_fp32,
             coeff)