"include/git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "54635d9f5d8591805b92e8e2bc38a70c5a465940"
Commit 78b7a1dc authored by Tri Dao

[OPT] Load fp16 weights on CPU before moving to GPU

parent 33e0860c
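Rough intuition for the change: if the checkpoint is loaded straight to the GPU, the weights briefly exist there twice (the freshly initialized fp16 model plus the loaded state_dict), roughly doubling peak weight memory, or worse if the checkpoint is stored in fp32. Loading on CPU and converting to fp16 there keeps the GPU peak near the size of the model alone. A back-of-the-envelope illustration (the parameter count is approximate and only for intuition):

import torch  # not strictly needed; shown for context

# Approximate peak GPU weight memory for OPT-6.7B during loading (illustrative only).
n_params = 6.7e9
fp16_model_gb = n_params * 2 / 1e9   # ~13.4 GB: the instantiated fp16 model
peak_old_gb = fp16_model_gb * 2      # model + loaded state_dict both on GPU -> ~26.8 GB
peak_new_gb = fp16_model_gb          # state_dict stays on CPU until load_state_dict copies it in
print(f"old peak ~{peak_old_gb:.1f} GB, new peak ~{peak_new_gb:.1f} GB")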
@@ -166,9 +166,10 @@ class GPTPreTrainedModel(nn.Module):
         """
         # Instantiate model.
         model = cls(config, *args, device=device, dtype=dtype, **kwargs)
-        # If we're going to shard the model, then don't load fp32 weights to GPU.
+        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
+        # want extra stuff taking up more GPU memory
         state_dict = state_dict_from_pretrained(
-            model_name, device=device if world_size == 1 else None, dtype=dtype
+            model_name, device='cpu', dtype=dtype
         )
         if model_name.startswith('gpt2'):
             state_dict = remap_state_dict_gpt2(state_dict, config)
@@ -178,7 +179,6 @@ class GPTPreTrainedModel(nn.Module):
             raise NotImplementedError(f'Model {model_name} not supported')
         if world_size > 1:
             state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
-        state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
         load_return = model.load_state_dict(state_dict, strict=strict)
         logger.info(load_return)
         return model
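A note on why dropping the explicit .to(device=device) pass is safe here: load_state_dict copies source tensors into the existing parameters in place, regardless of which device the source tensors live on. A minimal, self-contained illustration (the toy module and shapes are made up, not from the repository):

import torch
import torch.nn as nn

# load_state_dict copies CPU tensors into the existing GPU parameters,
# so the state_dict itself never has to be moved to the GPU first.
model = nn.Linear(16, 16, device='cuda', dtype=torch.float16)
cpu_state_dict = {k: torch.zeros_like(v, device='cpu') for k, v in model.state_dict().items()}
model.load_state_dict(cpu_state_dict)
print(next(model.parameters()).device)  # still cuda:0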
...
@@ -43,6 +43,8 @@ def remap_state_dict_opt(state_dict, config):
     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
+        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
+        key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
         key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
                      r'transformer.layers.\1.norm1.', key)
         key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
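The extra regex only matters for checkpoints that use the OPT-175B naming. A small standalone illustration of the two rules (the example keys are made up; the patterns are the ones from the diff):

import re

def key_mapping_ln(key):
    key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
    # OPT-175B uses 'layer_norm' where the smaller checkpoints use 'final_layer_norm'.
    key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
    return key

print(key_mapping_ln('transformer.final_layer_norm.weight'))  # transformer.ln_f.weight
print(key_mapping_ln('transformer.layer_norm.weight'))        # transformer.ln_f.weight (175B-style key)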
...
@@ -196,7 +196,7 @@ class DecodingCGCache:

 @torch.inference_mode()
 def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
-                       dtype=None):
+                       dtype=None, n_warmups=2):
     if cache is None:
         cache = DecodingCGCache()
     param_example = next(iter(model.parameters()))
@@ -228,7 +228,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         if s_type not in cache.callables:
             seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
             cache.callables[s_type] = capture_graph(
-                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
+                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool,
+                n_warmups=n_warmups
             )

     def dispatch(input_ids, position_ids, seqlen):
@@ -239,7 +240,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
     return cache


-def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
+def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None,
+                  n_warmups=2):
     assert max_seqlen >= seqlen_og
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
@@ -250,10 +252,15 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
     s = torch.cuda.Stream()
     s.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(s):
-        for _ in range(2):
+        for _ in range(n_warmups):
             logits = model(input_ids, position_ids=position_ids,
                            inference_params=inference_params).logits[:, -1]
         s.synchronize()
+        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
+        # which requires that graph launch and non-captured launch to not overlap (I think,
+        # that's how I interpret the documentation). I'm not sure if this is required.
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
     torch.cuda.current_stream().wait_stream(s)
     # Captures the graph
     # To allow capture, automatically sets a side stream as the current stream in the context
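For readers unfamiliar with the warmup-then-capture pattern that n_warmups parametrizes, here is a rough, generic sketch of the flow around the lines touched above. It is an illustration under assumptions (a generic run_step callable stands in for the model forward; this is not the repository's capture_graph):

import torch

def capture_graph_sketch(run_step, n_warmups=2, mempool=None):
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        # Warmup iterations allocate workspaces and trigger autotuning outside the capture.
        for _ in range(n_warmups):
            out = run_step()
        s.synchronize()
        # With NCCL_GRAPH_MIXING_SUPPORT=0, graph and non-graph launches should not overlap,
        # so a barrier here keeps all ranks lined up (per the comment in the diff above).
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        out = run_step()

    def replay():
        graph.replay()
        return out
    return replay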
...
@@ -7,6 +7,8 @@ from transformers.utils.hub import cached_file, get_checkpoint_shard_files


 def state_dict_from_pretrained(model_name, device=None, dtype=None):
+    # If not fp32, then we don't want to load directly to the GPU
+    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
     is_sharded = False
     resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                         _raise_exceptions_for_missing_entries=False)
@@ -25,9 +27,11 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         )
         state_dict = {}
         for sharded_file in resolved_archive_file:
-            state_dict.update(torch.load(sharded_file, map_location=device))
+            state_dict.update(torch.load(sharded_file, map_location=mapped_device))
     else:
         state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+    # Convert dtype before moving to GPU to save memory
     if dtype is not None:
-        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
+        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
+    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
     return state_dict
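The ordering of the last few lines is the core of the memory optimization: convert dtype while the tensors are still on the CPU, then move the already-halved tensors to the GPU. A tiny standalone illustration (the single tensor is a stand-in for a checkpoint dict with many entries):

import torch

ckpt = {'weight': torch.randn(1024, 1024)}                       # loaded on CPU (fp32 here)
ckpt = {k: v.to(dtype=torch.float16) for k, v in ckpt.items()}   # halve it while still on CPU
ckpt = {k: v.to(device='cuda') for k, v in ckpt.items()}         # only the fp16 copy reaches the GPU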
...
@@ -114,7 +114,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):


 @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"])
-# @pytest.mark.parametrize('model_name', ["facebook/opt-6.7b"])
+# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
 def test_greedy_decode_opt(model_name):
     """Check that our implementation of OPT generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -145,7 +145,7 @@ def test_greedy_decode_opt(model_name):
     input_ids = tokenizer("Hello, my dog is cute and",
                           return_tensors="pt").input_ids.to(device=device)
-    max_length = 30
+    max_length = 60
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
@@ -192,7 +192,7 @@ def test_greedy_decode_opt(model_name):
     print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     if verbose:
         print(out_cg.sequences)
-        print(tokenizer.batch_decode(out.sequences.tolist()))
+        print(tokenizer.batch_decode(out_cg.sequences.tolist()))
     del model
...
@@ -129,3 +129,5 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
     assert torch.all(out.sequences == out_hf.sequences)
     assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
+
+    parallel_state.destroy_model_parallel()
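One possible way to make this cleanup automatic across parametrized runs would be an autouse fixture; the following is a hypothetical sketch, not part of the diff, and it assumes parallel_state is the same module the test file already imports:

import pytest

@pytest.fixture(autouse=True)
def _cleanup_model_parallel():
    yield
    # Mirror the explicit call added at the end of the test above.
    parallel_state.destroy_model_parallel()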