Commit 34e07701 authored by rprenger

Merge branch 'main' into stop_tokens

parents 9b131fad f5345dfa
......@@ -11,7 +11,7 @@ Below are some of the projects where we have directly used Megatron:
* [Scaling Language Model Training to a Trillion Parameters Using Megatron](https://arxiv.org/pdf/2104.04473.pdf)
* [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf)
Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specific model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. The table below shows the model configurations along with the achieved FLOPs (both per GPU and aggregate over all GPUs). Note that the FLOPs are measured for end-to-end training, i.e., they include all operations, including data loading, optimization, and even logging.
Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specific model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. The table below shows the model configurations along with the achieved FLOPs (both per GPU and aggregate over all GPUs). Note that these results are from benchmark runs and these models were not trained to convergence; however, the FLOPs are measured for end-to-end training, i.e., they include all operations, including data loading, optimization, and even logging.
![Cases](images/cases_april2021.png)
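
Since the table maps layer counts and hidden sizes to overall model sizes, a quick back-of-the-envelope parameter estimate can help make sense of it. Below is a minimal sketch (not part of the repo) using the usual ~12 · layers · hidden² approximation for the transformer blocks plus the embedding tables; the example configurations are illustrative, not the exact ones from the table.

```python
# Rough sketch (not from this repo): estimate the parameter count of a GPT
# configuration from its layer count, hidden size, and vocabulary, using the
# standard ~12 * layers * hidden^2 approximation for the transformer blocks
# plus the token/position embedding tables.
def approx_gpt_params(num_layers, hidden_size, vocab_size=51200, seq_length=2048):
    # Each transformer layer: ~4*h^2 for attention + ~8*h^2 for the MLP.
    block_params = 12 * num_layers * hidden_size ** 2
    # Token embeddings plus learned position embeddings.
    embedding_params = (vocab_size + seq_length) * hidden_size
    return block_params + embedding_params

if __name__ == "__main__":
    # Illustrative configurations, not the exact rows of the table above.
    for layers, hidden in [(24, 2048), (40, 5120), (96, 12288)]:
        print(f"{layers} layers, hidden {hidden}: "
              f"~{approx_gpt_params(layers, hidden) / 1e9:.1f}B parameters")
```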
......
......@@ -179,10 +179,6 @@ class ParallelAttention(MegatronModule):
init_method=output_layer_init_method,
skip_bias_add=True)
# Inference key-value memory
self.inference_key_memory = None
self.inference_value_memory = None
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
......@@ -203,19 +199,18 @@ class ParallelAttention(MegatronModule):
# Pre-allocate memory for key-values for inference.
# =================================================
if inference_params:
if inference_params.allocate_key_value_memory:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
self.inference_key_memory = self._allocate_memory(
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
self.inference_value_memory = self._allocate_memory(
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
# This is added for safety. In case inference_params
# is not provided, make sure there is no potential memory left
# from previous inference.
else:
self.inference_key_memory = None
self.inference_value_memory = None
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory, inference_value_memory)
else:
inference_key_memory, inference_value_memory = \
inference_params.key_value_memory_dict[self.layer_number]
# =====================
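
The hunk above replaces the module-level `self.inference_key_memory` / `self.inference_value_memory` buffers with per-layer entries in `inference_params.key_value_memory_dict`, keyed by `self.layer_number`, so allocation happens once the first time a layer runs. A minimal standalone sketch of that pattern, with simplified names and shapes (an illustration, not the actual Megatron code):

```python
import torch

# Minimal sketch of the new caching pattern: key/value buffers live in a dict
# on the inference-params object, keyed by layer number, instead of as module
# attributes. Shapes and dtypes here are simplified assumptions.
class SimpleInferenceParams:
    def __init__(self, max_batch_size, max_sequence_len):
        self.max_batch_size = max_batch_size
        self.max_sequence_len = max_sequence_len
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}

def get_kv_memory(inference_params, layer_number, num_heads, head_dim):
    # Allocate the per-layer buffers only on the first call for this layer.
    if layer_number not in inference_params.key_value_memory_dict:
        shape = (inference_params.max_sequence_len,
                 inference_params.max_batch_size,
                 num_heads, head_dim)
        key_memory = torch.empty(shape)
        value_memory = torch.empty(shape)
        inference_params.key_value_memory_dict[layer_number] = (key_memory,
                                                                value_memory)
    return inference_params.key_value_memory_dict[layer_number]
```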
......@@ -266,20 +261,18 @@ class ParallelAttention(MegatronModule):
if inference_params:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= self.inference_key_memory.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= self.inference_key_memory.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
self.inference_key_memory[sequence_start:sequence_end,
batch_start:batch_end,
...] = key_layer
self.inference_value_memory[sequence_start:sequence_end,
batch_start:batch_end,
...] = value_layer
key_layer = self.inference_key_memory[
inference_key_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = key_layer
inference_value_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[
:sequence_end, batch_start:batch_end, ...]
value_layer = self.inference_value_memory[
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...]
......
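
The second hunk writes the current step's keys and values into the cached buffers at the batch and sequence offsets tracked by `inference_params`, then attends over everything cached up to `sequence_end`. A hedged sketch of that update-and-slice step, reusing the `get_kv_memory` helper sketched above (the `[seq, batch, heads, head_dim]` tensor layout is an assumption):

```python
def update_kv_cache(inference_params, layer_number, key_layer, value_layer,
                    num_heads, head_dim):
    # key_layer / value_layer: [seq, batch, num_heads, head_dim] for this step.
    key_memory, value_memory = get_kv_memory(
        inference_params, layer_number, num_heads, head_dim)

    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + key_layer.size(1)
    sequence_start = inference_params.sequence_len_offset
    sequence_end = sequence_start + key_layer.size(0)

    # Write the new keys/values into their slot in the cache ...
    key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
    value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer

    # ... and attend over everything cached so far for this batch slice.
    key_layer = key_memory[:sequence_end, batch_start:batch_end, ...]
    value_layer = value_memory[:sequence_end, batch_start:batch_end, ...]
    return key_layer, value_layer
```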
......@@ -119,7 +119,7 @@ def generate(model,
# Note that these tensors are broadcast to all ranks.
if torch.distributed.get_rank() == 0:
assert prompts is not None
#assert tokens_to_generate > 0
context_tokens_tensor, context_length_tensor = tokenize_prompts(
prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
......
......@@ -40,7 +40,7 @@ class InferenceParams:
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.allocate_key_value_memory = True
self.key_value_memory_dict = {}
......@@ -132,11 +132,6 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
# Send output to the next stage.
send_to_next_pipeline_rank(output_tensor)
# Make sure we do not allocate context memory anymore.
if inference_params.allocate_key_value_memory:
inference_params.allocate_key_value_memory = False
return output_tensor
......
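
With allocation keyed by layer number, the `allocate_key_value_memory` flag and the code above that cleared it after the first forward step become unnecessary: the dict membership check handles first-time allocation, and later steps only advance the offsets. A toy decode loop illustrating how those offsets could be driven, reusing the `SimpleInferenceParams` sketch above (`model_step` is a hypothetical callable, not Megatron's generation API):

```python
# Toy decode loop showing how sequence_len_offset advances between steps;
# model_step and the tensor layouts are placeholders, not Megatron's actual API.
def toy_decode(model_step, tokens, max_new_tokens, max_batch_size, max_sequence_len):
    params = SimpleInferenceParams(max_batch_size, max_sequence_len)

    # First step processes the whole prompt; the per-layer caches are
    # allocated lazily inside the attention layers on this call.
    logits = model_step(tokens, params)           # tokens: [seq, batch]
    params.sequence_len_offset += tokens.size(0)

    for _ in range(max_new_tokens):
        next_token = logits.argmax(dim=-1)[-1:, :]  # greedy pick, [1, batch]
        logits = model_step(next_token, params)     # single-token step
        params.sequence_len_offset += 1             # advance the cache offset
    return params
```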