Unverified Commit c88c9709 authored by Yifan Xiong's avatar Yifan Xiong Committed by GitHub
Browse files

Benchmarks - Support TE FP8 in BERT/GPT2 models (#496)

Support Transformer Engine FP8 in existing PyTorch BERT/GPT2 models by
converting linear/layernorm to TE layers.
parent 8daef211
...@@ -9,6 +9,10 @@ ...@@ -9,6 +9,10 @@
import torch import torch
import transformers import transformers
try:
import transformer_engine.pytorch as te
except ImportError:
te = None
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.distributed import TCPStore, PrefixStore from torch.distributed import TCPStore, PrefixStore
...@@ -44,6 +48,40 @@ def _set_force_fp32(self): ...@@ -44,6 +48,40 @@ def _set_force_fp32(self):
torch.backends.cuda.matmul.allow_tf32 = not self._args.force_fp32 torch.backends.cuda.matmul.allow_tf32 = not self._args.force_fp32
torch.backends.cudnn.allow_tf32 = not self._args.force_fp32 torch.backends.cudnn.allow_tf32 = not self._args.force_fp32
@torch.no_grad()
def _to_te_model(self, model):
    """Convert eligible submodules of the model to Transformer Engine layers, in place.

    Recursively replaces each `torch.nn.Linear` / `torch.nn.LayerNorm` child with the
    equivalent `te.Linear` / `te.LayerNorm` and copies the trained parameters over, so the
    model can later run under `te.fp8_autocast`. No-op when transformer_engine could not
    be imported (`te` is None).

    Modified based on Huggingface's utils `accelerate.accelerator.convert_model`, reference:
    https://github.com/huggingface/accelerate/blob/v0.17.1/src/accelerate/utils/transformer_engine.py#L24

    Args:
        model (torch.nn.Module): Torch model, modified in place.
    """
    if not te:
        return
    for name, m in model.named_children():
        if isinstance(m, torch.nn.Linear):
            # TE FP8 GEMMs require every weight dimension to be a multiple of 16.
            # Skip only the incompatible layer (bug fix: previously `return`, which
            # aborted conversion of all remaining layers and left the model in an
            # inconsistently half-converted state).
            if any(p % 16 != 0 for p in m.weight.shape):
                continue
            te_m = te.Linear(m.in_features, m.out_features, bias=(m.bias is not None))
            te_m.weight.copy_(m.weight)
            if m.bias is not None:
                te_m.bias.copy_(m.bias)
            setattr(model, name, te_m)
        elif isinstance(m, torch.nn.LayerNorm):
            te_m = te.LayerNorm(m.normalized_shape[0], eps=m.eps)
            # Parameter attribute names differ across transformer_engine versions.
            if hasattr(te_m, 'weight'):
                te_m.weight.copy_(m.weight)
                te_m.bias.copy_(m.bias)
            else:
                te_m.layer_norm_weight.copy_(m.weight)
                te_m.layer_norm_bias.copy_(m.bias)
            setattr(model, name, te_m)
        else:
            # Not a convertible leaf — recurse into the child module.
            self._to_te_model(m)
def _init_distributed_setting(self): def _init_distributed_setting(self):
"""Initialize the distributed library and bind the worker to GPU. """Initialize the distributed library and bind the worker to GPU.
......
...@@ -47,58 +47,6 @@ def forward(self, input): ...@@ -47,58 +47,6 @@ def forward(self, input):
return result return result
class TeBertBenchmarkModel(torch.nn.Module):
    """BERT model using Transformer Engine."""
    def __init__(self, config, num_classes):
        """Constructor.

        Args:
            config (BertConfig): Configurations of BERT model.
            num_classes (int): The number of objects for classification.
        """
        super().__init__()
        self._embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
        # Encoder stack built from te.TransformerLayer instead of nn.TransformerEncoderLayer.
        # Each layer consumes tensors shaped (seq_len, batch_size, hidden_size).
        encoder_stack = [
            te.TransformerLayer(
                config.hidden_size,
                config.intermediate_size,
                config.num_attention_heads,
                apply_residual_connection_post_layernorm=True,
                output_layernorm=True,
                layer_type='encoder',
            )
            for _ in range(config.num_hidden_layers)
        ]
        self._encoder_layers = torch.nn.ModuleList(encoder_stack)
        # Mirrors BertPooler used in huggingface transformers:
        # https://github.com/huggingface/transformers/blob/accad48e/src/transformers/models/bert/modeling_bert.py#L893
        self._pooler = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.hidden_size),
            torch.nn.Tanh(),
        )
        self._linear = torch.nn.Linear(config.hidden_size, num_classes)

    def forward(self, input):
        """Forward propagation function.

        Args:
            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
                shape (batch_size, sequence_length).

        Return:
            out (torch.FloatTensor): Last layer hidden-state of the first token of the sequence
                (classification token) further processed by a Linear layer, shape (batch_size, hidden_size).
        """
        # Move batch dim last so the embedding output is (seq_len, batch_size, hidden_size).
        hidden = self._embedding(input.movedim(0, -1))
        for encoder in self._encoder_layers:
            hidden = encoder(hidden, attention_mask=None)
        # Restore batch-first layout, take the first ([CLS]) token, pool, then classify.
        cls_state = hidden.movedim(0, 1)[:, 0]
        return self._linear(self._pooler(cls_state))
class PytorchBERT(PytorchBase): class PytorchBERT(PytorchBase):
"""The BERT benchmark class.""" """The BERT benchmark class."""
def __init__(self, name, parameters=''): def __init__(self, name, parameters=''):
...@@ -183,15 +131,15 @@ def _create_model(self, precision): ...@@ -183,15 +131,15 @@ def _create_model(self, precision):
return False return False
try: try:
self._model = BertBenchmarkModel(self._config, self._args.num_classes)
if enable_fp8: if enable_fp8:
self._fp8_recipe = DelayedScaling( self._fp8_recipe = DelayedScaling(
fp8_format=Format[precision.name.strip('FP8_')], fp8_format=Format[precision.name.strip('FP8_')],
amax_history_len=16, amax_history_len=16,
amax_compute_algo='max', amax_compute_algo='max',
) )
self._model = TeBertBenchmarkModel(self._config, self._args.num_classes).to(dtype=torch.float16) self._to_te_model(self._model.to(dtype=torch.float16))
else: else:
self._model = BertBenchmarkModel(self._config, self._args.num_classes)
self._model = self._model.to(dtype=getattr(torch, precision.value)) self._model = self._model.to(dtype=getattr(torch, precision.value))
if self._gpu_available: if self._gpu_available:
self._model = self._model.cuda() self._model = self._model.cuda()
......
...@@ -5,6 +5,11 @@ ...@@ -5,6 +5,11 @@
import torch import torch
from transformers import GPT2Model, GPT2Config from transformers import GPT2Model, GPT2Config
try:
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
except ImportError:
te = None
from superbench.common.utils import logger from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision from superbench.benchmarks import BenchmarkRegistry, Precision
...@@ -23,7 +28,7 @@ def __init__(self, config, num_classes): ...@@ -23,7 +28,7 @@ def __init__(self, config, num_classes):
num_classes (int): The number of objects for classification. num_classes (int): The number of objects for classification.
""" """
super().__init__() super().__init__()
self._bert = GPT2Model(config) self._gpt2 = GPT2Model(config)
self._linear = torch.nn.Linear(config.hidden_size, num_classes) self._linear = torch.nn.Linear(config.hidden_size, num_classes)
def forward(self, input): def forward(self, input):
...@@ -37,7 +42,7 @@ def forward(self, input): ...@@ -37,7 +42,7 @@ def forward(self, input):
result (torch.FloatTensor): Last layer hidden-state of the first token of the sequence result (torch.FloatTensor): Last layer hidden-state of the first token of the sequence
(classification token) further processed by a Linear layer, shape (batch_size, hidden_size). (classification token) further processed by a Linear layer, shape (batch_size, hidden_size).
""" """
outputs = self._bert(input) outputs = self._gpt2(input)
result = self._linear(outputs[0]) result = self._linear(outputs[0])
return result return result
...@@ -53,7 +58,13 @@ def __init__(self, name, parameters=''): ...@@ -53,7 +58,13 @@ def __init__(self, name, parameters=''):
""" """
super().__init__(name, parameters) super().__init__(name, parameters)
self._config = None self._config = None
self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16] self._fp8_recipe = None
self._supported_precision = [
Precision.FLOAT32,
Precision.FLOAT16,
Precision.FP8_HYBRID,
Precision.FP8_E4M3,
]
self._optimizer_type = Optimizer.ADAMW self._optimizer_type = Optimizer.ADAMW
self._loss_fn = torch.nn.CrossEntropyLoss() self._loss_fn = torch.nn.CrossEntropyLoss()
...@@ -99,9 +110,31 @@ def _create_model(self, precision): ...@@ -99,9 +110,31 @@ def _create_model(self, precision):
n_embd=self._args.hidden_size, n_layer=self._args.num_hidden_layers, n_head=self._args.num_attention_heads n_embd=self._args.hidden_size, n_layer=self._args.num_hidden_layers, n_head=self._args.num_attention_heads
) )
enable_fp8 = precision.name.startswith('FP8_')
if enable_fp8 and te is None:
logger.error(
f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
' message: Cannot find transformer_engine.'
)
return False
if enable_fp8 and not self._gpu_available:
logger.error(
f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
' message: FP8 is only supported on GPU.'
)
return False
try: try:
self._model = GPT2BenchmarkModel(self._config, self._args.num_classes) self._model = GPT2BenchmarkModel(self._config, self._args.num_classes)
self._model = self._model.to(dtype=getattr(torch, precision.value)) if enable_fp8:
self._fp8_recipe = DelayedScaling(
fp8_format=Format[precision.name.strip('FP8_')],
amax_history_len=16,
amax_compute_algo='max',
)
self._to_te_model(self._model.to(dtype=torch.float16))
else:
self._model = self._model.to(dtype=getattr(torch, precision.value))
if self._gpu_available: if self._gpu_available:
self._model = self._model.cuda() self._model = self._model.cuda()
except BaseException as e: except BaseException as e:
...@@ -136,7 +169,11 @@ def _train_step(self, precision): ...@@ -136,7 +169,11 @@ def _train_step(self, precision):
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._optimizer.zero_grad() self._optimizer.zero_grad()
output = self._model(sample) if self._fp8_recipe is not None:
with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
output = self._model(sample)
else:
output = self._model(sample)
loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target) loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target)
loss.backward() loss.backward()
self._optimizer.step() self._optimizer.step()
...@@ -168,7 +205,11 @@ def _inference_step(self, precision): ...@@ -168,7 +205,11 @@ def _inference_step(self, precision):
start = self._timer() start = self._timer()
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._model(sample) if self._fp8_recipe is not None:
with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
self._model(sample)
else:
self._model(sample)
end = self._timer() end = self._timer()
curr_step += 1 curr_step += 1
if curr_step > self._args.num_warmup: if curr_step > self._args.num_warmup:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment