Commit 88173a1a authored by Tri Dao

[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP

parent 780e8eea
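The rename is mechanical (FusedDenseGeluDense -> FusedMLP, config flag fused_dense_gelu_dense -> fused_mlp), but the module now also accepts activation='relu' next to the tanh-approximate GELU. A minimal usage sketch, assuming a CUDA build of the fused_dense extension; the constructor arguments are the ones visible in this diff:

import torch
from flash_attn.ops.fused_dense import FusedMLP

# Feature dims here are multiples of 128, which the fused ReLU path needs when it
# saves its bit-mask pre-activation during training.
mlp = FusedMLP(1024, 4096, activation='relu', checkpoint_lvl=0, heuristic='auto',
               device='cuda', dtype=torch.float16)
x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)                                # same shape as x: (8, 512, 1024)

# return_residual=True also hands back x, so a post-norm block can fuse the
# residual add into the backward of the following nn.Linear.
mlp_res = FusedMLP(1024, 4096, return_residual=True, device='cuda', dtype=torch.float16)
y2, residual = mlp_res(x)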
...@@ -28,19 +28,19 @@ ...@@ -28,19 +28,19 @@
} }
template <typename T> template <typename T>
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace); int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);
template <typename T> template <typename T>
int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) ; int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act);
template <typename T> template <typename T>
int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace); int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias);
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) { std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
int batch_size = input.size(0); int64_t batch_size = input.size(0);
int in_features = input.size(1); int64_t in_features = input.size(1);
int out_features = d_output.size(1); int64_t out_features = d_output.size(1);
TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16); TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.dtype() == d_output.dtype()); TORCH_CHECK(input.dtype() == d_output.dtype());
...@@ -66,8 +66,6 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, ...@@ -66,8 +66,6 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
d_bias = at::empty({out_features}, opts); d_bias = at::empty({out_features}, opts);
#endif #endif
} }
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] { DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
auto result = linear_bias_wgrad_cuda<scalar_t>( auto result = linear_bias_wgrad_cuda<scalar_t>(
...@@ -77,21 +75,20 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, ...@@ -77,21 +75,20 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
batch_size, batch_size,
out_features, out_features,
d_weight.data_ptr<scalar_t>(), d_weight.data_ptr<scalar_t>(),
has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr, has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_wgrad failed."); TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
}); });
return {d_weight, d_bias}; return {d_weight, d_bias};
} }
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
c10::optional<at::Tensor> bias_, c10::optional<at::Tensor> bias_,
bool save_gelu_in, int heuristic) { bool is_gelu, bool save_pre_act, int heuristic) {
int batch_size = input.size(0); int64_t batch_size = input.size(0);
int in_features = input.size(1); int64_t in_features = input.size(1);
int out_features = weight.size(0); int64_t out_features = weight.size(0);
TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16); TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.dtype() == weight.dtype()); TORCH_CHECK(input.dtype() == weight.dtype());
...@@ -116,51 +113,52 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, ...@@ -116,51 +113,52 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
// create output/workspace tensor // create output/workspace tensor
auto opts = input.options(); auto opts = input.options();
auto output = at::empty({batch_size, out_features}, opts); auto output = at::empty({batch_size, out_features}, opts);
at::Tensor gelu_in; at::Tensor pre_act;
if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); } // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
auto lt_workspace = at::empty({1 << 22}, opts); is_gelu ? opts : opts.dtype(torch::kUInt8)); }
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] { DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
auto result = linear_gelu_forward_cuda<scalar_t>( auto result = linear_act_forward_cuda<scalar_t>(
input.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
bias_.has_value()? bias_.value().data_ptr<scalar_t>() : nullptr, bias_.has_value()? bias_.value().data_ptr<scalar_t>() : nullptr,
in_features, in_features,
batch_size, batch_size,
out_features, out_features,
is_gelu,
heuristic, heuristic,
output.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr, save_pre_act ? pre_act.data_ptr() : nullptr);
(void*) (lt_workspace.data_ptr<scalar_t>())); TORCH_CHECK(result == 0, "linear_act_forward failed.");
TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
}); });
std::vector<at::Tensor> result = {output}; std::vector<at::Tensor> result = {output};
if (save_gelu_in) { result.push_back(gelu_in); }; if (save_pre_act) { result.push_back(pre_act); };
return result; return result;
} }
std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad( std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic at::Tensor weight, at::Tensor d_output, at::Tensor pre_act, bool is_gelu, int heuristic
) { ) {
int batch_size = d_output.size(0); int64_t batch_size = d_output.size(0);
int out_features = d_output.size(1); int64_t out_features = d_output.size(1);
int in_features = weight.size(1); int64_t in_features = weight.size(1);
TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16); TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
TORCH_CHECK(weight.dtype() == d_output.dtype()); TORCH_CHECK(weight.dtype() == d_output.dtype());
TORCH_CHECK(weight.dtype() == gelu_in.dtype()); TORCH_CHECK(is_gelu ? (pre_act.dtype() == weight.dtype()) : (pre_act.dtype() == torch::kUInt8));
TORCH_CHECK(weight.is_cuda()); TORCH_CHECK(weight.is_cuda());
TORCH_CHECK(d_output.is_cuda()); TORCH_CHECK(d_output.is_cuda());
TORCH_CHECK(gelu_in.is_cuda()); TORCH_CHECK(pre_act.is_cuda());
TORCH_CHECK(weight.is_contiguous()); TORCH_CHECK(weight.is_contiguous());
TORCH_CHECK(d_output.is_contiguous()); TORCH_CHECK(d_output.is_contiguous());
TORCH_CHECK(gelu_in.is_contiguous()); TORCH_CHECK(pre_act.is_contiguous());
CHECK_SHAPE(weight, out_features, in_features); CHECK_SHAPE(weight, out_features, in_features);
CHECK_SHAPE(d_output, batch_size, out_features); CHECK_SHAPE(d_output, batch_size, out_features);
CHECK_SHAPE(gelu_in, batch_size, in_features); // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
CHECK_SHAPE(pre_act, batch_size, is_gelu ? in_features : in_features / 8);
// Otherwise the kernel will be launched from cuda:0 device // Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing // Cast to char to avoid compiler warning about narrowing
...@@ -170,22 +168,20 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad( ...@@ -170,22 +168,20 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
auto opts = weight.options(); auto opts = weight.options();
auto d_bias = at::empty({in_features}, opts); auto d_bias = at::empty({in_features}, opts);
auto d_input = at::empty({batch_size, in_features}, opts); auto d_input = at::empty({batch_size, in_features}, opts);
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] { DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>( auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
weight.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
d_output.data_ptr<scalar_t>(), d_output.data_ptr<scalar_t>(),
gelu_in.data_ptr<scalar_t>(), pre_act.data_ptr(),
in_features, in_features,
batch_size, batch_size,
out_features, out_features,
is_gelu,
heuristic, heuristic,
d_input.data_ptr<scalar_t>(), d_input.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(), d_bias.data_ptr<scalar_t>());
(void*) (lt_workspace.data_ptr<scalar_t>())); TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
}); });
return {d_input, d_bias}; return {d_input, d_bias};
...@@ -193,6 +189,6 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad( ...@@ -193,6 +189,6 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad"); m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward"); m.def("linear_act_forward", &linear_act_forward, "linear gelu/relu forward");
m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad"); m.def("bias_act_linear_dgrad_bgrad", &bias_act_linear_dgrad_bgrad, "bias gelu/relu linear dgrad bgrad");
} }
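From Python, the renamed binding keeps the calling convention declared above: linear_act_forward(input, weight, bias, is_gelu, save_pre_act, heuristic) returns the output and, when save_pre_act is true, the pre-activation — a (batch, out_features) tensor for GELU but a uint8 bit-mask of shape (batch, out_features / 8) for ReLU. A hedged sketch exercising the binding directly (the import name fused_dense_lib is an assumption, taken from how flash_attn.ops.fused_dense imports it as fused_dense_cuda; a recent CUDA toolkit is assumed for the ReLU aux epilogue):

import torch
import fused_dense_lib as fused_dense_cuda  # assumed import name of this extension

x = torch.randn(32, 1024, device='cuda', dtype=torch.float16)
w = torch.randn(4096, 1024, device='cuda', dtype=torch.float16) / 32
b = torch.zeros(4096, device='cuda', dtype=torch.float16)

out, pre_act = fused_dense_cuda.linear_act_forward(x, w, b, False, True, 0)  # is_gelu=False -> ReLU
print(out.shape, pre_act.shape, pre_act.dtype)    # (32, 4096) (32, 512) torch.uint8

out_g, pre_act_g = fused_dense_cuda.linear_act_forward(x, w, b, True, True, 0)
print(pre_act_g.shape, pre_act_g.dtype)           # (32, 4096) torch.float16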
...@@ -23,7 +23,7 @@ from transformers.models.bert.modeling_bert import BertForPreTrainingOutput ...@@ -23,7 +23,7 @@ from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
from einops import rearrange from einops import rearrange
from flash_attn.modules.mha import MHA from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense from flash_attn.modules.mlp import Mlp, FusedMLP
from flash_attn.modules.block import Block from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.bert_padding import unpad_input, pad_input from flash_attn.bert_padding import unpad_input, pad_input
...@@ -61,24 +61,24 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False): ...@@ -61,24 +61,24 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
def create_mlp_cls(config, layer_idx=None, return_residual=False): def create_mlp_cls(config, layer_idx=None, return_residual=False):
inner_dim = config.intermediate_size inner_dim = config.intermediate_size
fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False) fused_mlp = getattr(config, 'fused_mlp', False)
if fused_dense_gelu_dense: if fused_mlp:
assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only ' assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_mlp only '
'supports approximate gelu') 'supports approximate gelu')
if not fused_dense_gelu_dense: if not fused_mlp:
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none' approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
mlp_cls = partial(Mlp, hidden_features=inner_dim, mlp_cls = partial(Mlp, hidden_features=inner_dim,
activation=partial(F.gelu, approximate=approximate), activation=partial(F.gelu, approximate=approximate),
return_residual=return_residual) return_residual=return_residual)
else: else:
if FusedDenseGeluDense is None: if FusedMLP is None:
raise ImportError('fused_dense is not installed') raise ImportError('fused_dense is not installed')
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0) mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if isinstance(mlp_checkpoint_lvl, Sequence): if isinstance(mlp_checkpoint_lvl, Sequence):
assert layer_idx is not None assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim, mlp_cls = partial(FusedMLP, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual) checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
return mlp_cls return mlp_cls
......
...@@ -17,7 +17,7 @@ from transformers import GPT2Config ...@@ -17,7 +17,7 @@ from transformers import GPT2Config
from einops import rearrange from einops import rearrange
from flash_attn.modules.mha import MHA, ParallelMHA from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
...@@ -77,22 +77,22 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt ...@@ -77,22 +77,22 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype} factory_kwargs = {'device': device, 'dtype': dtype}
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False) fused_mlp = getattr(config, 'fused_mlp', False)
if fused_dense_gelu_dense: if fused_mlp:
assert config.activation_function in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only ' assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
'supports approximate gelu')
fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False) fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
if fused_dense_sqrelu_dense: if fused_dense_sqrelu_dense:
assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only ' assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
'supports approximate activation_function sqrelu') 'supports approximate activation_function sqrelu')
assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense) assert not (fused_dense_sqrelu_dense and fused_mlp)
if process_group is not None: if process_group is not None:
assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense' assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense: if not fused_mlp and not fused_dense_sqrelu_dense:
if config.activation_function == 'relu': if config.activation_function == 'relu':
activation = partial(F.relu, inplace=True) activation = partial(F.relu, inplace=True)
else: else:
approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none' approximate = ('tanh' if config.activation_function
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
activation=partial(F.gelu, approximate=approximate) activation=partial(F.gelu, approximate=approximate)
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs) mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
else: else:
...@@ -101,14 +101,17 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp ...@@ -101,14 +101,17 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
if isinstance(mlp_checkpoint_lvl, Sequence): if isinstance(mlp_checkpoint_lvl, Sequence):
assert layer_idx is not None assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
if fused_dense_gelu_dense: if fused_mlp:
if FusedDenseGeluDense is None: if FusedMLP is None:
raise ImportError('fused_dense is not installed') raise ImportError('fused_dense is not installed')
mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense activation = ('gelu_approx' if config.activation_function
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
parallel_kwargs = ({'process_group': process_group, parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)} 'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {}) if process_group is not None else {})
mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl, mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
checkpoint_lvl=mlp_checkpoint_lvl,
**parallel_kwargs, **factory_kwargs) **parallel_kwargs, **factory_kwargs)
elif fused_dense_sqrelu_dense: elif fused_dense_sqrelu_dense:
assert FusedDenseSqreluDense is not None assert FusedDenseSqreluDense is not None
...@@ -210,7 +213,8 @@ class GPTModel(GPTPreTrainedModel): ...@@ -210,7 +213,8 @@ class GPTModel(GPTPreTrainedModel):
factory_kwargs = {'device': device, 'dtype': dtype} factory_kwargs = {'device': device, 'dtype': dtype}
self.process_group = process_group self.process_group = process_group
self.sequence_parallel = getattr(config, 'sequence_parallel', True) self.sequence_parallel = getattr(config, 'sequence_parallel', True)
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'relu', 'sqrelu'] assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
'relu', 'sqrelu']
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple) * pad_vocab_size_multiple)
......
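With the mapping above, a GPT config can now pair fused_mlp with ReLU instead of the approximate GELU. An illustrative config, using the field names read by create_mlp_cls above (the model classes then build FusedMLP / ParallelFusedMLP with activation='relu'):

from transformers import GPT2Config

config = GPT2Config()                  # gpt2-small sized defaults
config.activation_function = 'relu'    # accepted alongside gelu_new / gelu_fast / gelu_approx
config.fused_mlp = True                # renamed from fused_dense_gelu_dense
config.mlp_checkpoint_lvl = 0          # or a per-layer list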
...@@ -20,7 +20,7 @@ from timm.models.helpers import named_apply ...@@ -20,7 +20,7 @@ from timm.models.helpers import named_apply
from flash_attn.layers.patch_embed import PatchEmbed from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.mha import MHA from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense from flash_attn.modules.mlp import Mlp, FusedMLP
from flash_attn.modules.block import Block from flash_attn.modules.block import Block
try: try:
...@@ -37,22 +37,22 @@ def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_ ...@@ -37,22 +37,22 @@ def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_
return mixer_cls return mixer_cls
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense): def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
inner_dim = int(embed_dim * mlp_ratio) inner_dim = int(embed_dim * mlp_ratio)
if not fused_dense_gelu_dense: if not fused_mlp:
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer()) mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
else: else:
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim) mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
return mlp_cls return mlp_cls
def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc, drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
fused_dense_gelu_dense, fused_dropout_add_ln, layer_idx=None, n_layer=None, fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None,
last_layer_subset=False): last_layer_subset=False):
mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc, mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
cross_attn=(last_layer_subset and layer_idx == n_layer - 1)) cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense) mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer, block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate, prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate,
...@@ -92,7 +92,7 @@ class VisionTransformer(nn.Module): ...@@ -92,7 +92,7 @@ class VisionTransformer(nn.Module):
act_layer=None, act_layer=None,
use_flash_attn=False, use_flash_attn=False,
fused_bias_fc=False, fused_bias_fc=False,
fused_dense_gelu_dense=False, fused_mlp=False,
fused_dropout_add_ln=False, fused_dropout_add_ln=False,
): ):
""" """
...@@ -164,7 +164,7 @@ class VisionTransformer(nn.Module): ...@@ -164,7 +164,7 @@ class VisionTransformer(nn.Module):
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i], drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i],
norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn, norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense, fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp,
fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth, fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
last_layer_subset=(global_pool == 'token') last_layer_subset=(global_pool == 'token')
) for i in range(depth)]) ) for i in range(depth)])
......
...@@ -121,6 +121,7 @@ class Block(nn.Module): ...@@ -121,6 +121,7 @@ class Block(nn.Module):
) )
if mixer_kwargs is None: if mixer_kwargs is None:
mixer_kwargs = {} mixer_kwargs = {}
if mixer_subset is not None:
mixer_kwargs['mixer_subset'] = mixer_subset mixer_kwargs['mixer_subset'] = mixer_subset
hidden_states = self.mixer(hidden_states, **mixer_kwargs) hidden_states = self.mixer(hidden_states, **mixer_kwargs)
if mixer_subset is not None: if mixer_subset is not None:
......
...@@ -5,9 +5,9 @@ import torch.nn as nn ...@@ -5,9 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
try: try:
from flash_attn.ops.fused_dense import FusedDenseGeluDense, ParallelFusedDenseGeluDense from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError: except ImportError:
FusedDenseGeluDense, ParallelFusedDenseGeluDense = None, None FusedMLP, ParallelFusedMLP = None, None
class Mlp(nn.Module): class Mlp(nn.Module):
......
# Copyright (c) 2022, Tri Dao. # Copyright (c) 2023, Tri Dao.
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py # Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16. # We make it work with pytorch amp and with bfloat16.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional from typing import Optional
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -19,6 +20,11 @@ from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all ...@@ -19,6 +20,11 @@ from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all
from flash_attn.utils.distributed import reduce_scatter, all_reduce from flash_attn.utils.distributed import reduce_scatter, all_reduce
@torch.jit.script
def relu_bwd(g, x):
return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
class FusedDenseFunc(torch.autograd.Function): class FusedDenseFunc(torch.autograd.Function):
@staticmethod @staticmethod
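relu_bwd mirrors gelu_bwd: for ReLU the backward only needs to know which pre-activation entries were non-negative, which is why a 1-bit mask per element is enough to save. A quick sanity check of the rule against autograd (at exact zeros the subgradient choice differs, harmlessly, but random floats never hit them):

import torch
import torch.nn.functional as F

x = torch.randn(4, 4096, requires_grad=True)
g = torch.randn_like(x)

F.relu(x).backward(g)
manual = torch.where(x.detach() >= 0, g, torch.zeros_like(g))  # same rule as relu_bwd above
print(torch.allclose(x.grad, manual))                          # True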
...@@ -185,12 +191,13 @@ class RowParallelLinear(nn.Linear): ...@@ -185,12 +191,13 @@ class RowParallelLinear(nn.Linear):
return reduce_fn(out, self.process_group) return reduce_fn(out, self.process_group)
class FusedDenseGeluDenseFunc(torch.autograd.Function): class FusedMLPFunc(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd @custom_fwd
def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False, def forward(ctx, x, weight1, bias1, weight2, bias2, activation='gelu_approx', save_pre_act=True,
checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True): return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None,
sequence_parallel=True):
""" """
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather of x before doing the matmul. with sequence parallelism: we do an all_gather of x before doing the matmul.
...@@ -198,10 +205,11 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): ...@@ -198,10 +205,11 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
checkpoint_lvl: checkpoint_lvl:
0: no recomputation in the bwd 0: no recomputation in the bwd
1: recompute gelu_out in the bwd 1: recompute gelu_out / relu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd 2: recompute pre_act and gelu_out / relu_out in the bwd
""" """
assert -1 <= heuristic <= 4 assert -1 <= heuristic <= 4
assert activation in ['gelu_approx', 'relu']
if not save_pre_act: if not save_pre_act:
checkpoint_lvl = 2 checkpoint_lvl = 2
assert checkpoint_lvl in [0, 1, 2] assert checkpoint_lvl in [0, 1, 2]
...@@ -209,6 +217,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): ...@@ -209,6 +217,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
ctx.process_group = process_group ctx.process_group = process_group
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.checkpoint_lvl = checkpoint_lvl ctx.checkpoint_lvl = checkpoint_lvl
ctx.activation = activation
ctx.heuristic = heuristic ctx.heuristic = heuristic
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
...@@ -237,23 +246,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): ...@@ -237,23 +246,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32: if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
raise RuntimeError('fused_dense only supports matrix dims <= 2M') raise RuntimeError('fused_dense only supports matrix dims <= 2M')
if heuristic == -1: if heuristic == -1:
gelu_in = F.linear(total_x, weight1, bias1) pre_act = F.linear(total_x, weight1, bias1)
output1 = F.gelu(gelu_in, approximate='tanh') activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else F.relu)
output1 = activation_fn(pre_act)
# This is before adding bias1 # This is before adding bias1
# gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1) # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
# with torch.jit.fuser('fuser2'): # with torch.jit.fuser('fuser2'):
# output1 = bias_gelu(gelu_in, bias1) # output1 = bias_gelu(pre_act, bias1)
else: else:
output1, *rest = fused_dense_cuda.linear_gelu_forward( is_gelu = activation == 'gelu_approx'
total_x.reshape(batch_dim, n), weight1, bias1, save_pre_act, heuristic output1, *rest = fused_dense_cuda.linear_act_forward(
total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
) )
if save_pre_act: if save_pre_act:
gelu_in = rest[0] pre_act = rest[0]
output2 = F.linear(output1, weight2, bias2) output2 = F.linear(output1, weight2, bias2)
if checkpoint_lvl == 0: if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
ctx.save_for_backward(x, weight1, weight2, gelu_in, output1) # For RELU the pre_act is very small (just a bit-mask) so we just save it
ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
elif checkpoint_lvl == 1: elif checkpoint_lvl == 1:
ctx.save_for_backward(x, weight1, weight2, gelu_in) ctx.save_for_backward(x, weight1, weight2, pre_act)
elif checkpoint_lvl == 2: elif checkpoint_lvl == 2:
ctx.save_for_backward(x, weight1, weight2, bias1) ctx.save_for_backward(x, weight1, weight2, bias1)
output2 = output2.reshape(*batch_shape, output2.shape[-1]) output2 = output2.reshape(*batch_shape, output2.shape[-1])
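For ReLU, checkpoint_lvl=1 falls back to the lvl-0 behavior: the bit-mask pre-activation only records signs, so output1 cannot be recomputed from it, and since the mask is 16x smaller than an fp16/bf16 pre-activation, saving output1 alongside it costs about the same memory as GELU's lvl 1 anyway. Rough numbers for one MLP block (illustrative sizes):

tokens, d_ff = 8 * 2048, 4096            # batch * seqlen and MLP hidden size (illustrative)
gelu_pre_act_bytes = tokens * d_ff * 2   # fp16/bf16 pre-activation saved for GELU
relu_mask_bytes = tokens * d_ff // 8     # 1 bit per element saved for ReLU
print(gelu_pre_act_bytes // relu_mask_bytes)   # 16x smaller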
...@@ -264,6 +277,9 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): ...@@ -264,6 +277,9 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
def backward(ctx, grad_output, *args): def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous() grad_output = grad_output.contiguous()
checkpoint_lvl = ctx.checkpoint_lvl checkpoint_lvl = ctx.checkpoint_lvl
activation = ctx.activation
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else F.relu)
if ctx.return_residual: if ctx.return_residual:
grad_input, = args grad_input, = args
grad_input = grad_input.contiguous() grad_input = grad_input.contiguous()
...@@ -277,27 +293,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): ...@@ -277,27 +293,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
if checkpoint_lvl in [0, 1]: if checkpoint_lvl in [0, 1]:
if process_group is not None and sequence_parallel: if process_group is not None and sequence_parallel:
total_x, handle_x = all_gather_raw(x, process_group, async_op=True) total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
if checkpoint_lvl == 0: if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
gelu_in, output1 = rest pre_act, output1 = rest
elif checkpoint_lvl == 1: elif checkpoint_lvl == 1:
gelu_in, = rest pre_act, = rest
output1 = F.gelu(gelu_in, approximate='tanh') output1 = activation_fn(pre_act)
elif checkpoint_lvl == 2: elif checkpoint_lvl == 2:
bias1, = rest bias1, = rest
if process_group is not None and sequence_parallel: if process_group is not None and sequence_parallel:
total_x, _ = all_gather_raw(x, process_group) total_x, _ = all_gather_raw(x, process_group)
if ctx.heuristic == -1: if ctx.heuristic == -1:
gelu_in = F.linear(total_x, weight1, bias1) pre_act = F.linear(total_x, weight1, bias1)
output1 = F.gelu(gelu_in, approximate='tanh') output1 = activation_fn(pre_act)
else: else:
output1, gelu_in = fused_dense_cuda.linear_gelu_forward( output1, pre_act = fused_dense_cuda.linear_act_forward(
total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, True, total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1,
ctx.heuristic activation == 'gelu_approx', True, ctx.heuristic
) )
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
output1 = output1.reshape(batch_dim, output1.shape[-1]) output1 = output1.reshape(batch_dim, output1.shape[-1])
gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1]) pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
if ctx.needs_input_grad[3]: if ctx.needs_input_grad[3]:
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad( grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
output1, grad_output, ctx.needs_input_grad[4] output1, grad_output, ctx.needs_input_grad[4]
...@@ -306,24 +322,25 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): ...@@ -306,24 +322,25 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
grad_weight2 = None grad_weight2 = None
grad_bias2 = grad_output if ctx.needs_input_grad[4] else None grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
if ctx.heuristic == -1: if ctx.heuristic == -1:
# grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in) # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
grad_output1 = F.linear(grad_output, weight2.t()) grad_output1 = F.linear(grad_output, weight2.t())
with torch.jit.fuser('fuser2'): with torch.jit.fuser('fuser2'):
grad_gelu = gelu_bwd(grad_output1, gelu_in) activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd
grad_pre_act = activation_grad_fn(grad_output1, pre_act)
else: else:
# The cublasLt epilogue has to compute both gelu grad and bias grad, we can't # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
# just compute gelu grad # just compute gelu/relu grad
grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad( grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
weight2, grad_output, gelu_in, ctx.heuristic weight2, grad_output, pre_act, activation == 'gelu_approx', ctx.heuristic
) )
if not ctx.needs_input_grad[2]: if not ctx.needs_input_grad[2]:
grad_bias1 = None grad_bias1 = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
if not ctx.return_residual: if not ctx.return_residual:
grad_input = F.linear(grad_gelu, weight1.t()) grad_input = F.linear(grad_pre_act, weight1.t())
else: else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
grad_gelu, weight1) grad_pre_act, weight1)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None: if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
...@@ -335,55 +352,60 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): ...@@ -335,55 +352,60 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
if process_group is not None and sequence_parallel: if process_group is not None and sequence_parallel:
handle_x.wait() handle_x.wait()
grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu, total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act,
ctx.needs_input_grad[2] ctx.needs_input_grad[2]
) )
else: else:
grad_weight1 = None grad_weight1 = None
grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
else: else:
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
if process_group is not None and sequence_parallel: if process_group is not None and sequence_parallel:
handle_x.wait() handle_x.wait()
grad_weight1 = F.linear(grad_gelu.t(), grad_weight1 = F.linear(grad_pre_act.t(),
total_x.reshape(batch_dim, total_x.shape[-1]).t()) total_x.reshape(batch_dim, total_x.shape[-1]).t())
else: else:
grad_weight1 = None grad_weight1 = None
if process_group is not None and ctx.needs_input_grad[0]: if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait() handle_grad_input.wait()
return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
None, None, None, None, None, None) None, None, None, None, None, None, None)
def fused_dense_gelu_dense_func( def fused_mlp_func(
x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None, x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
bias2: Optional[Tensor] = None, bias2: Optional[Tensor] = None, activation: str = 'gelu_approx',
save_pre_act: bool = True, return_residual: bool = False, save_pre_act: bool = True, return_residual: bool = False,
checkpoint_lvl: int = 0, heuristic: int = 0, checkpoint_lvl: int = 0, heuristic: int = 0,
process_group: Optional[ProcessGroup] = None, process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True sequence_parallel: bool = True
): ):
assert activation in ['gelu_approx', 'relu']
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
or (x.dtype == torch.float32 and torch.is_autocast_enabled())) or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
# If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == 'relu' else 8) == 0)
if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda) if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
and (bias2 is None or bias2.is_cuda) and dtype_eligible): and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible):
return FusedDenseGeluDenseFunc.apply( return FusedMLPFunc.apply(
x, weight1, bias1, weight2, bias2, save_pre_act, return_residual, x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual,
checkpoint_lvl, heuristic, process_group, sequence_parallel checkpoint_lvl, heuristic, process_group, sequence_parallel
) )
else: else:
assert process_group is None assert process_group is None
gelu_in = F.linear(x, weight1, bias1) pre_act = F.linear(x, weight1, bias1)
output1 = F.gelu(gelu_in, approximate='tanh') activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else partial(F.relu, inplace=True))
output1 = activation_fn(pre_act)
output2 = F.linear(output1, weight2, bias2) output2 = F.linear(output1, weight2, bias2)
return output2 if not return_residual else (output2, x) return output2 if not return_residual else (output2, x)
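fused_mlp_func falls back to the plain F.linear + activation + F.linear path whenever the inputs are not eligible for the fused kernels — CPU tensors, fp32 without autocast, or (when the pre-activation must be saved) a last dimension not divisible by 128 for ReLU or 8 for GELU — so it can be called unconditionally as long as process_group is None. A sketch of the fallback:

import torch
from flash_attn.ops.fused_dense import fused_mlp_func

# fp32 on CPU is not dtype-eligible, so this takes the unfused branch above.
x = torch.randn(4, 1000)
w1 = torch.randn(4000, 1000) * 0.02
w2 = torch.randn(1000, 4000) * 0.02
out = fused_mlp_func(x, w1, w2, activation='relu')
print(out.shape)   # torch.Size([4, 1000])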
class FusedDenseGeluDense(nn.Module): class FusedMLP(nn.Module):
def __init__(self, in_features, hidden_features, out_features=None, bias1=True, def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0, bias2=True, activation='gelu_approx', return_residual=False,
device=None, dtype=None): checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
""" """
If process_group is not None, we're doing Tensor Parallel with sequence parallelism: If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul, gelu, then matmul. we do an all_gather of x before doing the matmul, gelu, then matmul.
...@@ -392,21 +414,24 @@ class FusedDenseGeluDense(nn.Module): ...@@ -392,21 +414,24 @@ class FusedDenseGeluDense(nn.Module):
checkpoint_lvl (increasing lvl means slower but more memory saving): checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd 0: no recomputation in the bwd
1: recompute gelu_out in the bwd 1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd 2: recompute pre_act and gelu_out in the bwd
heuristic: heuristic:
-1: don't fuse gemm + gelu (separate kernel) -1: don't fuse gemm + gelu (separate kernel)
0..4: use this heuristic for the algo section in the fused gemm + gelu 0..4: use this heuristic for the algo section in the fused gemm + gelu
For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf. 'auto': heuristic will be picked automatically:
For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16. For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
return_residual: whether to return the input x along with the output. This is for return_residual: whether to return the input x along with the output. This is for
performance reason: for post-norm architecture, returning the input allows us performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection. to fuse the backward of nn.Linear with the residual connection.
""" """
assert checkpoint_lvl in [0, 1, 2] assert checkpoint_lvl in [0, 1, 2]
assert activation in ['gelu_approx', 'relu']
factory_kwargs = {'device': device, 'dtype': dtype} factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__() super().__init__()
if out_features is None: if out_features is None:
out_features = in_features out_features = in_features
self.activation = activation
self.return_residual = return_residual self.return_residual = return_residual
self.checkpoint_lvl = checkpoint_lvl self.checkpoint_lvl = checkpoint_lvl
self.heuristic = heuristic self.heuristic = heuristic
...@@ -414,11 +439,20 @@ class FusedDenseGeluDense(nn.Module): ...@@ -414,11 +439,20 @@ class FusedDenseGeluDense(nn.Module):
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
def forward(self, x, process_group=None): def forward(self, x, process_group=None):
out = fused_dense_gelu_dense_func( dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
if self.heuristic == 'auto':
if self.activation == 'gelu_approx':
cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
else:
heuristic = 0
else:
heuristic = self.heuristic
out = fused_mlp_func(
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
save_pre_act=self.training, return_residual=self.return_residual, activation=self.activation, save_pre_act=self.training,
checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic, return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl,
process_group=process_group heuristic=heuristic, process_group=process_group
) )
if self.return_residual: if self.return_residual:
out, x = out out, x = out
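With heuristic='auto', the algo choice is resolved per call from the activation, the autocast-aware dtype, and the CUDA version, matching the guidance that previously lived only in the docstring. The decision pulled out for illustration (resolve_heuristic is a hypothetical helper; the module does this inline above):

import torch

def resolve_heuristic(activation: str, dtype: torch.dtype, cuda_version=(11, 8)) -> int:
    # ReLU always uses heuristic 0; approximate GELU depends on CUDA version and dtype.
    if activation != 'gelu_approx':
        return 0
    return 0 if cuda_version >= (11, 8) else (1 if dtype == torch.float16 else -1)

print(resolve_heuristic('gelu_approx', torch.bfloat16, (11, 7)))  # -1: unfused GEMM + GELU
print(resolve_heuristic('gelu_approx', torch.float16, (11, 7)))   #  1
print(resolve_heuristic('relu', torch.float16, (12, 0)))          #  0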
...@@ -427,11 +461,12 @@ class FusedDenseGeluDense(nn.Module): ...@@ -427,11 +461,12 @@ class FusedDenseGeluDense(nn.Module):
return out if not self.return_residual else (out, x) return out if not self.return_residual else (out, x)
class ParallelFusedDenseGeluDense(nn.Module): class ParallelFusedMLP(nn.Module):
def __init__(self, in_features, hidden_features, out_features=None, def __init__(self, in_features, hidden_features, out_features=None, activation='gelu_approx',
process_group: ProcessGroup = None, bias1=True, bias2=True, process_group: ProcessGroup = None, bias1=True, bias2=True,
sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None): sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
device=None, dtype=None):
""" """
process_group is required. We're doing Tensor Parallel with sequence parallelism: process_group is required. We're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul, gelu, then matmul. we do an all_gather of x before doing the matmul, gelu, then matmul.
...@@ -440,19 +475,22 @@ class ParallelFusedDenseGeluDense(nn.Module): ...@@ -440,19 +475,22 @@ class ParallelFusedDenseGeluDense(nn.Module):
checkpoint_lvl (increasing lvl means slower but more memory saving): checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd 0: no recomputation in the bwd
1: recompute gelu_out in the bwd 1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd 2: recompute pre_act and gelu_out in the bwd
heuristic: heuristic:
-1: don't fuse gemm + gelu (separate kernel) -1: don't fuse gemm + gelu (separate kernel)
0..4: use this heuristic for the algo section in the fused gemm + gelu 0..4: use this heuristic for the algo section in the fused gemm + gelu
For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf. 'auto': heuristic will be picked automatically:
For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16. For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
""" """
assert checkpoint_lvl in [0, 1, 2] assert checkpoint_lvl in [0, 1, 2]
assert activation in ['gelu_approx', 'relu']
assert process_group is not None assert process_group is not None
factory_kwargs = {'device': device, 'dtype': dtype} factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__() super().__init__()
if out_features is None: if out_features is None:
out_features = in_features out_features = in_features
self.activation = activation
self.process_group = process_group self.process_group = process_group
self.sequence_parallel = sequence_parallel self.sequence_parallel = sequence_parallel
self.checkpoint_lvl = checkpoint_lvl self.checkpoint_lvl = checkpoint_lvl
...@@ -463,10 +501,19 @@ class ParallelFusedDenseGeluDense(nn.Module): ...@@ -463,10 +501,19 @@ class ParallelFusedDenseGeluDense(nn.Module):
bias=bias2, **factory_kwargs) bias=bias2, **factory_kwargs)
def forward(self, x): def forward(self, x):
out = fused_dense_gelu_dense_func( dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
if self.heuristic == 'auto':
if self.activation == 'gelu_approx':
cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
else:
heuristic = 0
else:
heuristic = self.heuristic
out = fused_mlp_func(
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl, activation=self.activation, save_pre_act=self.training,
heuristic=self.heuristic, checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic,
process_group=self.process_group, process_group=self.process_group,
sequence_parallel=self.sequence_parallel sequence_parallel=self.sequence_parallel
) )
......
...@@ -95,13 +95,13 @@ def test_bert_optimized(model_name): ...@@ -95,13 +95,13 @@ def test_bert_optimized(model_name):
""" """
dtype = torch.float16 dtype = torch.float16
config = BertConfig.from_pretrained(model_name) config = BertConfig.from_pretrained(model_name)
# Our implementation of fused_dense_gelu_dense assumes the activation is # Our implementation of fused_mlp assumes the activation is
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast". # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
# If you just want "gelu", disable fused_dense_gelu_dense. # If you just want "gelu", disable fused_mlp.
config.hidden_act = "gelu_new" config.hidden_act = "gelu_new"
config.use_flash_attn = True config.use_flash_attn = True
config.fused_bias_fc = True config.fused_bias_fc = True
config.fused_dense_gelu_dense = True config.fused_mlp = True
config.fused_dropout_add_ln = True config.fused_dropout_add_ln = True
model = BertForPreTraining.from_pretrained(model_name, config) model = BertForPreTraining.from_pretrained(model_name, config)
...@@ -171,13 +171,13 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs ...@@ -171,13 +171,13 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
""" """
dtype = torch.float16 dtype = torch.float16
config = BertConfig.from_pretrained(model_name) config = BertConfig.from_pretrained(model_name)
# Our implementation of fused_dense_gelu_dense assumes the activation is # Our implementation of fused_mlp assumes the activation is
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast". # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
# If you just want "gelu", disable fused_dense_gelu_dense. # If you just want "gelu", disable fused_mlp.
config.hidden_act = "gelu_new" config.hidden_act = "gelu_new"
config.use_flash_attn = True config.use_flash_attn = True
config.fused_bias_fc = True config.fused_bias_fc = True
config.fused_dense_gelu_dense = True config.fused_mlp = True
config.fused_dropout_add_ln = True config.fused_dropout_add_ln = True
config.dense_seq_output = True config.dense_seq_output = True
config.last_layer_subset = last_layer_subset config.last_layer_subset = last_layer_subset
......
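The gelu_new / gelu_fast requirement is just the tanh approximation that the fused kernel implements; a quick numerical check (NewGELUActivation is the implementation transformers uses for 'gelu_new' in recent versions):

import torch
import torch.nn.functional as F
from transformers.activations import NewGELUActivation

x = torch.randn(4096, dtype=torch.float32)
print(torch.allclose(NewGELUActivation()(x), F.gelu(x, approximate='tanh'), atol=1e-6))  # True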
...@@ -82,7 +82,7 @@ def test_gpt2_optimized(model_name): ...@@ -82,7 +82,7 @@ def test_gpt2_optimized(model_name):
vocab_size_og = config.vocab_size vocab_size_og = config.vocab_size
config.use_flash_attn = True config.use_flash_attn = True
config.fused_bias_fc = True config.fused_bias_fc = True
config.fused_dense_gelu_dense = True config.fused_mlp = True
config.fused_dropout_add_ln = True config.fused_dropout_add_ln = True
config.residual_in_fp32 = True config.residual_in_fp32 = True
config.pad_vocab_size_multiple = 8 config.pad_vocab_size_multiple = 8
......
...@@ -18,7 +18,7 @@ from flash_attn.utils.distributed import all_gather_raw ...@@ -18,7 +18,7 @@ from flash_attn.utils.distributed import all_gather_raw
@pytest.mark.parametrize('fused_ft_kernel', [False, True]) @pytest.mark.parametrize('fused_ft_kernel', [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True]) # @pytest.mark.parametrize('fused_ft_kernel', [True])
@pytest.mark.parametrize('optimized', [False, True]) @pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [True]) # @pytest.mark.parametrize('optimized', [False])
@pytest.mark.parametrize('rotary', [False, True]) @pytest.mark.parametrize('rotary', [False, True])
# @pytest.mark.parametrize('rotary', [False]) # @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"]) @pytest.mark.parametrize('model_name', ["gpt2"])
...@@ -34,10 +34,11 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel): ...@@ -34,10 +34,11 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
if rotary: if rotary:
config.n_positions = 0 config.n_positions = 0
config.rotary_emb_dim = 64 config.rotary_emb_dim = 64
config.residual_in_fp32 = True
if optimized: if optimized:
config.use_flash_attn = True config.use_flash_attn = True
config.fused_bias_fc = True config.fused_bias_fc = True
config.fused_dense_gelu_dense = True config.fused_mlp = True
config.fused_dropout_add_ln = True config.fused_dropout_add_ln = True
# if not rotary, we load the weight from HF but ignore the position embeddings. # if not rotary, we load the weight from HF but ignore the position embeddings.
...@@ -78,6 +79,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel): ...@@ -78,6 +79,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
fused_ft_kernel=fused_ft_kernel, fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True) return_dict_in_generate=True, output_scores=True, timing=True)
print(out.sequences) print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel: if fused_ft_kernel:
out_cg = model.generate(input_ids=input_ids, max_length=max_length, out_cg = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=fused_ft_kernel, cg=True, fused_ft_kernel=fused_ft_kernel, cg=True,
...@@ -94,122 +96,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel): ...@@ -94,122 +96,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(tokenizer.batch_decode(out_ref.sequences.tolist()))
assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
rtol=rtol, atol=atol)
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation.py -k "parallel"
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize('world_size', [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('fused_ft_kernel', [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
dtype = torch.float16
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
if rotary:
config.n_positions = 0
config.rotary_emb_dim = 64
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False # Need to set this to False for generation
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)
from apex.transformer import parallel_state
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
# if not rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
dtype=dtype, process_group=process_group,
world_size=world_size, rank=rank)
model.eval()
if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ",
return_tensors="pt").input_ids.to(device=device)
max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
# Slow generation for reference
sequences = []
scores = []
cur_input_ids = input_ids
with torch.inference_mode():
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print(sequences)
out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out.sequences)
if fused_ft_kernel:
out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out_cg.sequences)
if not rotary:
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
assert torch.all(out.sequences == sequences) assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
......
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation_parallel.py -k "parallel"
import os
import re
import torch
import pytest
from einops import rearrange
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.distributed import all_gather_raw
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize('world_size', [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('fused_ft_kernel', [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
dtype = torch.float16
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
if rotary:
config.n_positions = 0
config.rotary_emb_dim = 64
config.residual_in_fp32 = True
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False # Need to set this to False for generation
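# (Sequence parallelism splits activations along the sequence dimension, which presumably
# does not play well with decoding one token at a time, hence the override above.)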
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
# Need this; otherwise, when we capture the CUDA graph, the process for GPU 1 would run on
# both GPU 0 and GPU 1 and things would hang.
torch.cuda.set_device(device)
from apex.transformer import parallel_state
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
# If not rotary, we load the weights from HF but ignore the position embeddings.
# The model would produce nonsense, but that doesn't matter for this test.
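# strict=not rotary below: with rotary embeddings enabled the HF checkpoint's position
# embedding weights have no counterpart in our model, so strict state-dict loading must be off.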
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
dtype=dtype, process_group=process_group,
world_size=world_size, rank=rank)
model.eval()
if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ",
return_tensors="pt").input_ids.to(device=device)
max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
# Slow generation for reference
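# Greedy decoding, one token per forward pass. Each tensor-parallel rank returns logits for
# its shard of the (padded) vocabulary; all_gather_raw stacks the shards along dim 0, so the
# rearrange below reassembles the full vocab per batch element before trimming the padding
# introduced by pad_vocab_size_multiple.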
sequences = []
scores = []
cur_input_ids = input_ids
with torch.inference_mode():
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print(sequences)
out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out.sequences)
if fused_ft_kernel:
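# Re-run generation with cg=True, which (per the torch.cuda.set_device note above) exercises
# CUDA-graph capture of the decoding step; the resulting sequences are printed for inspection.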
out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out_cg.sequences)
if not rotary:
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
rtol=rtol, atol=atol)
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
# Run test with: # Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py # torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -59,10 +61,12 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): ...@@ -59,10 +61,12 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
n_positions=seqlen if has_pos_emb else 0, n_positions=seqlen if has_pos_emb else 0,
vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0,
scale_attn_by_inverse_layer_idx=True, use_flash_attn=True, scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True, fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True,
residual_in_fp32=True,
rotary_emb_fraction=0.0 if has_pos_emb else 0.5, rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
pad_vocab_size_multiple=8 * world_size, pad_vocab_size_multiple=8 * world_size,
sequence_parallel=sequence_parallel) sequence_parallel=sequence_parallel)
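# Round the vocab size up to a multiple of 8 * world_size (added below), presumably so the
# non-parallel reference model and the tensor-parallel shards see the same padded vocabulary.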
config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
model_pt = GPTLMHeadModel(config, device=device) model_pt = GPTLMHeadModel(config, device=device)
def init_layer_norm(module): def init_layer_norm(module):
...@@ -131,9 +135,9 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): ...@@ -131,9 +135,9 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
grad_dict['transformer.embeddings.position_embeddings.weight'], grad_dict['transformer.embeddings.position_embeddings.weight'],
rtol=rtol, atol=atol rtol=rtol, atol=atol
) )
assert torch.allclose(model.transformer.ln_0.weight.grad, grad_dict['transformer.ln_0.weight'], assert torch.allclose(model.transformer.ln_f.weight.grad, grad_dict['transformer.ln_f.weight'],
rtol=rtol, atol=atol) rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.ln_0.bias.grad, grad_dict['transformer.ln_0.bias'], assert torch.allclose(model.transformer.ln_f.bias.grad, grad_dict['transformer.ln_f.bias'],
rtol=rtol, atol=atol) rtol=rtol, atol=atol)
for i in range(num_layers): for i in range(num_layers):
assert torch.allclose( assert torch.allclose(
......
...@@ -8,11 +8,11 @@ from timm.models.vision_transformer import vit_base_patch16_224 ...@@ -8,11 +8,11 @@ from timm.models.vision_transformer import vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224 from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True]) @pytest.mark.parametrize('fused_mlp', [False, True])
# @pytest.mark.parametrize('fused_dense_gelu_dense', [False]) # @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize('optimized', [False, True]) @pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [True]) # @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_dense_gelu_dense): def test_vit(optimized, fused_mlp):
"""Check that our implementation of ViT matches the timm's implementation: """Check that our implementation of ViT matches the timm's implementation:
the output of our forward pass in fp16 should be around the same as the output of our forward pass in fp16 should be around the same as
timm's forward pass in fp16, when compared to timm's forward pass in fp32. timm's forward pass in fp16, when compared to timm's forward pass in fp32.
...@@ -23,7 +23,7 @@ def test_vit(optimized, fused_dense_gelu_dense): ...@@ -23,7 +23,7 @@ def test_vit(optimized, fused_dense_gelu_dense):
kwargs = {} kwargs = {}
if optimized: if optimized:
kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True) kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense kwargs['fused_mlp'] = fused_mlp
model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype) model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
model_ref = vit_base_patch16_224(pretrained=True).to(device=device) model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
...@@ -46,4 +46,5 @@ def test_vit(optimized, fused_dense_gelu_dense): ...@@ -46,4 +46,5 @@ def test_vit(optimized, fused_dense_gelu_dense):
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}') print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}') print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
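# Acceptance criterion: our fp16 output may be at most 2x (4x when fused_mlp is on) as far
# from the fp32 reference as timm's own fp16 output is; 'rtol' below is that multiplier,
# not a torch.allclose-style relative tolerance.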
assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item() rtol = 2 if not fused_mlp else 4
assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()
...@@ -15,7 +15,7 @@ from apex.transformer import parallel_state ...@@ -15,7 +15,7 @@ from apex.transformer import parallel_state
from apex.transformer import tensor_parallel from apex.transformer import tensor_parallel
from flash_attn.modules.mha import MHA, ParallelMHA from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedDenseGeluDense, ParallelFusedDenseGeluDense from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block from flash_attn.modules.block import Block
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
...@@ -27,7 +27,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 ...@@ -27,7 +27,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('world_size', [1, 2, 4, 8]) @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2]) # @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False]) @pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False]) # @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize('dim', [1024]) @pytest.mark.parametrize('dim', [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype): def test_block_parallel(dim, sequence_parallel, world_size, dtype):
head_dim = 64 head_dim = 64
...@@ -62,8 +62,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype): ...@@ -62,8 +62,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2), mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True, device=device, dtype=dtype) use_flash_attn=True, device=device, dtype=dtype)
mlp_cls_pt = partial(FusedDenseGeluDense, hidden_features=4 * dim, mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
device=device, dtype=dtype)
norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype) norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True) model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
with torch.no_grad(): with torch.no_grad():
...@@ -76,7 +75,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype): ...@@ -76,7 +75,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
process_group=parallel_state.get_tensor_model_parallel_group(), process_group=parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True, rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
sequence_parallel=sequence_parallel, device=device, dtype=dtype) sequence_parallel=sequence_parallel, device=device, dtype=dtype)
mlp_cls = partial(ParallelFusedDenseGeluDense, hidden_features=4 * dim, mlp_cls = partial(ParallelFusedMLP, hidden_features=4 * dim,
process_group=parallel_state.get_tensor_model_parallel_group(), process_group=parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel, device=device, dtype=dtype) sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True, model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True,
...@@ -143,7 +142,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype): ...@@ -143,7 +142,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
x.grad, x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad, if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol / 100 # magnitude of x.grad is quite small rtol=rtol, atol=atol / 10 # magnitude of x.grad is quite small
) )
assert torch.allclose( assert torch.allclose(
residual.grad, residual.grad,
......
import math import math
from functools import partial
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -6,7 +7,7 @@ import pytest ...@@ -6,7 +7,7 @@ import pytest
from einops import rearrange from einops import rearrange
from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense from flash_attn.ops.fused_dense import FusedDense, FusedMLP
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
...@@ -60,14 +61,24 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, ...@@ -60,14 +61,24 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual,
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('heuristic', [0, -1]) # @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('heuristic', ['auto', -1])
# @pytest.mark.parametrize('heuristic', ['auto'])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2]) @pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
# @pytest.mark.parametrize('checkpoint_lvl', [1])
@pytest.mark.parametrize('return_residual', [False, True]) @pytest.mark.parametrize('return_residual', [False, True])
# @pytest.mark.parametrize('return_residual', [False])
@pytest.mark.parametrize('has_bias2', [True, False]) @pytest.mark.parametrize('has_bias2', [True, False])
@pytest.mark.parametrize('has_bias1', [True, False]) @pytest.mark.parametrize('has_bias1', [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
# @pytest.mark.parametrize('has_bias1', [True])
@pytest.mark.parametrize('activation', ['gelu_approx', 'relu'])
# @pytest.mark.parametrize('activation', ['relu'])
@pytest.mark.parametrize('out_features', [1024, 4096]) @pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096]) @pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual, # @pytest.mark.parametrize('out_features', [4096])
# @pytest.mark.parametrize('in_features', [1024])
def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2, return_residual,
checkpoint_lvl, heuristic, dtype): checkpoint_lvl, heuristic, dtype):
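# Assumed semantics (not verified here): checkpoint_lvl controls how much of the fc1 output /
# activation is stashed for backward vs. recomputed, and heuristic selects the cuBLASLt
# algorithm/epilogue ('auto' lets the extension choose; -1 keeps the activation out of the
# GEMM epilogue). The test only relies on every setting producing the same numerics.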
device = 'cuda' device = 'cuda'
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3) rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
...@@ -82,8 +93,8 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, ...@@ -82,8 +93,8 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
dtype=dtype) dtype=dtype)
model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device, model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
dtype=dtype) dtype=dtype)
model = FusedDenseGeluDense(in_features, out_features, in_features, bias1=has_bias1, model = FusedMLP(in_features, out_features, in_features, activation=activation,
bias2=has_bias2, return_residual=return_residual, bias1=has_bias1, bias2=has_bias2, return_residual=return_residual,
checkpoint_lvl=checkpoint_lvl, heuristic=heuristic, checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
device=device, dtype=dtype) device=device, dtype=dtype)
with torch.no_grad(): with torch.no_grad():
...@@ -93,7 +104,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, ...@@ -93,7 +104,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
model.fc2.weight.copy_(model_pt_fc2.weight) model.fc2.weight.copy_(model_pt_fc2.weight)
if has_bias2: if has_bias2:
model.fc2.bias.copy_(model_pt_fc2.bias) model.fc2.bias.copy_(model_pt_fc2.bias)
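# Unfused PyTorch reference for comparison: fc1 -> activation -> fc2, with the activation
# chosen to match the one the fused module was constructed with.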
out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else partial(F.relu, inplace=True))
out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))
if not return_residual: if not return_residual:
out = model(x) out = model(x)
else: else:
...@@ -107,6 +120,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, ...@@ -107,6 +120,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
g = torch.randn_like(out) / 32 g = torch.randn_like(out) / 32
out_pt.backward(g) out_pt.backward(g)
out.backward(g) out.backward(g)
# The error for relu is higher still
if activation == 'relu':
atol = 1e-1 if dtype == torch.bfloat16 else 5e-2
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher # The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10) assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
......
...@@ -10,8 +10,8 @@ import pytest ...@@ -10,8 +10,8 @@ import pytest
from apex.transformer import parallel_state from apex.transformer import parallel_state
from apex.transformer import tensor_parallel from apex.transformer import tensor_parallel
from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense from flash_attn.ops.fused_dense import FusedDense, FusedMLP
from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedMLP
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
...@@ -106,8 +106,7 @@ def test_fused_linear_bias(in_features, out_features, has_bias, sequence_paralle ...@@ -106,8 +106,7 @@ def test_fused_linear_bias(in_features, out_features, has_bias, sequence_paralle
# @pytest.mark.parametrize('has_bias2', [True]) # @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize('out_features', [4096]) @pytest.mark.parametrize('out_features', [4096])
@pytest.mark.parametrize('in_features', [1024]) @pytest.mark.parametrize('in_features', [1024])
def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_parallel, def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):
world_size, dtype):
assert out_features % world_size == 0 assert out_features % world_size == 0
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
...@@ -137,7 +136,7 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_p ...@@ -137,7 +136,7 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_p
dtype=dtype) dtype=dtype)
partition_out_features = out_features // world_size partition_out_features = out_features // world_size
partition_in_features = in_features // world_size partition_in_features = in_features // world_size
model = ParallelFusedDenseGeluDense(in_features, out_features, in_features, model = ParallelFusedMLP(in_features, out_features, in_features,
process_group=parallel_state.get_tensor_model_parallel_group(), process_group=parallel_state.get_tensor_model_parallel_group(),
bias2=has_bias2 and rank == 0, bias2=has_bias2 and rank == 0,
sequence_parallel=sequence_parallel, sequence_parallel=sequence_parallel,
......
...@@ -48,7 +48,7 @@ config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim, ...@@ -48,7 +48,7 @@ config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
n_layer=n_layer, n_head=nheads, n_layer=n_layer, n_head=nheads,
scale_attn_by_inverse_layer_idx=True, scale_attn_by_inverse_layer_idx=True,
rotary_emb_fraction=rotary_emb_fraction, rotary_emb_fraction=rotary_emb_fraction,
use_flash_attn=True, fused_dense_gelu_dense=True, use_flash_attn=True, fused_mlp=True,
fused_bias_fc=True, fused_dropout_add_ln=True, fused_bias_fc=True, fused_dropout_add_ln=True,
pad_vocab_size_multiple=8) pad_vocab_size_multiple=8)
model = GPTLMHeadModel(config) model = GPTLMHeadModel(config)
......
...@@ -7,9 +7,10 @@ defaults: ...@@ -7,9 +7,10 @@ defaults:
model: model:
config: config:
# n_positions is already set to ${datamodule.max_length} # n_positions is already set to ${datamodule.max_length}
residual_in_fp32: True
use_flash_attn: True use_flash_attn: True
fused_bias_fc: True fused_bias_fc: True
fused_dense_gelu_dense: True fused_mlp: True
fused_dropout_add_ln: True fused_dropout_add_ln: True
pad_vocab_size_multiple: 8 pad_vocab_size_multiple: 8
......
...@@ -7,9 +7,10 @@ defaults: ...@@ -7,9 +7,10 @@ defaults:
model: model:
config: config:
# n_positions is already set to ${datamodule.max_length} # n_positions is already set to ${datamodule.max_length}
residual_in_fp32: True
use_flash_attn: True use_flash_attn: True
fused_dropout_add_ln: True fused_dropout_add_ln: True
fused_dense_gelu_dense: True fused_mlp: True
fused_bias_fc: True fused_bias_fc: True
pad_vocab_size_multiple: 8 pad_vocab_size_multiple: 8
......