"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1c37746892a5fd680e88264346197bb313c8dd08"
Unverified Commit 66d1eee6 authored by кѳѳsнī, committed by GitHub

load_in_8bit now respects 'balanced' device maps in multi-gpu environments (#22377)

balanced 8bit memory
parent 8cfc6678
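
For context, a minimal usage sketch of what this change enables (the checkpoint name is only an example; bitsandbytes, accelerate, and a multi-GPU machine are assumed): with load_in_8bit=True, a balanced device map now has its per-GPU memory budgets computed against int8 weight sizes instead of the full-precision dtype.

```python
# Hedged usage sketch, not part of the diff: load a causal LM in 8-bit with a
# balanced device map. "bigscience/bloom-7b1" is just an example checkpoint;
# bitsandbytes, accelerate, and at least two visible GPUs are assumed.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-7b1",
    load_in_8bit=True,       # quantize linear layers to int8 via bitsandbytes
    device_map="balanced",   # spread layers evenly; memory is now estimated at int8 size
)
print(model.hf_device_map)   # inspect the resulting module -> device assignment
```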
@@ -2542,11 +2542,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             ) >= version.parse("0.37.0")

         if isinstance(device_map, str):
-            special_dtypes = {
-                name: torch.float32
-                for name, _ in model.named_parameters()
-                if any(m in name for m in keep_in_fp32_modules)
-            }
+            special_dtypes = {}
+            if load_in_8bit:
+                special_dtypes.update(
+                    {
+                        name: torch_dtype
+                        for name, _ in model.named_parameters()
+                        if any(m in name for m in modules_to_not_convert)
+                    }
+                )
+
+            special_dtypes.update(
+                {
+                    name: torch.float32
+                    for name, _ in model.named_parameters()
+                    if any(m in name for m in keep_in_fp32_modules)
+                }
+            )
+
             if model._no_split_modules is None:
                 raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
             no_split_modules = model._no_split_modules
@@ -2569,7 +2582,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             if device_map != "sequential" and get_balanced_memory is not None:
                 max_memory = get_balanced_memory(
                     model,
-                    dtype=torch_dtype,
+                    dtype=torch_dtype if not load_in_8bit else torch.int8,
                     low_zero=(device_map == "balanced_low_0"),
                     max_memory=max_memory,
                     **kwargs,
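
To see why the dtype override matters, here is a small sketch (assumptions: accelerate is installed, at least one CUDA device is visible, and the toy model is arbitrary) that calls accelerate's get_balanced_memory directly with fp16 versus int8 weight sizes:

```python
# Sketch only: compare the per-device budgets accelerate computes for a toy
# model when weights are counted at fp16 size vs. int8 size. The numbers
# depend entirely on the GPUs visible on the machine running it.
import torch
from torch import nn
from accelerate.utils import get_balanced_memory

toy_model = nn.Sequential(*[nn.Linear(2048, 2048) for _ in range(16)])

fp16_budget = get_balanced_memory(toy_model, dtype=torch.float16)
int8_budget = get_balanced_memory(toy_model, dtype=torch.int8)  # what an 8-bit load now requests

print("fp16:", fp16_budget)
print("int8:", int8_budget)
```

With int8 sizes, the balanced split reflects what the quantized model actually needs per GPU; modules listed in modules_to_not_convert are still counted at torch_dtype via special_dtypes, and keep_in_fp32_modules at float32, as the first hunk shows.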
@@ -785,7 +785,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
-            # Enable model/pipeline parallelism
+            # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)