"src/diffusers/schedulers/scheduling_dpm_cogvideox.py" did not exist on "769f0be8fb41daca9f3cbcffcfd0dbf01cc194b8"
Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import math
from deepspeed.utils import logger
from .constants import *
class CurriculumScheduler(object):
def __init__(self, config):
super().__init__()
self.state = {}
......@@ -16,17 +19,12 @@ class CurriculumScheduler(object):
f"Curriculum learning requires the config '{CURRICULUM_LEARNING_MAX_DIFFICULTY}'"
assert CURRICULUM_LEARNING_SCHEDULE_TYPE in config, \
f"Curriculum learning requires the config '{CURRICULUM_LEARNING_SCHEDULE_TYPE}'"
self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY] = config[
CURRICULUM_LEARNING_MIN_DIFFICULTY]
self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY] = config[
CURRICULUM_LEARNING_MAX_DIFFICULTY]
self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] = config[
CURRICULUM_LEARNING_MIN_DIFFICULTY]
self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] = config[
CURRICULUM_LEARNING_SCHEDULE_TYPE]
self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY] = config[CURRICULUM_LEARNING_MIN_DIFFICULTY]
self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY] = config[CURRICULUM_LEARNING_MAX_DIFFICULTY]
self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] = config[CURRICULUM_LEARNING_MIN_DIFFICULTY]
self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] = config[CURRICULUM_LEARNING_SCHEDULE_TYPE]
self.first_step = True
if config[
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_DISCRETE:
if config[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_DISCRETE:
"""
The schedule_config is a list of difficulties and a list of the max
step belonging to each difficulty. Example json config:
......@@ -43,18 +41,12 @@ class CurriculumScheduler(object):
f"Curriculum learning with fixed_discrete schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY}'"
assert CURRICULUM_LEARNING_SCHEDULE_MAX_STEP in config[CURRICULUM_LEARNING_SCHEDULE_CONFIG], \
f"Curriculum learning with fixed_discrete schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_MAX_STEP}'"
assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
[CURRICULUM_LEARNING_SCHEDULE_MAX_STEP]) > 0
assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
[CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY]) > 0
assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
[CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY]) == len(
config[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
[CURRICULUM_LEARNING_SCHEDULE_MAX_STEP]) + 1
self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[
CURRICULUM_LEARNING_SCHEDULE_CONFIG]
elif config[
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_ROOT:
assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_MAX_STEP]) > 0
assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY]) > 0
assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY]) == len(
config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_MAX_STEP]) + 1
self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
elif config[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_ROOT:
"""
The schedule_config includes:
total_curriculum_step: how many steps the curriculum learning takes to go
......@@ -79,15 +71,12 @@ class CurriculumScheduler(object):
f"Curriculum learning with fixed_root schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP}'"
assert CURRICULUM_LEARNING_SCHEDULE_ROOT_DEGREE in config[CURRICULUM_LEARNING_SCHEDULE_CONFIG], \
f"Curriculum learning with fixed_root schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_ROOT_DEGREE}'"
if config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][
CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP] % 8 != 0:
if config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP] % 8 != 0:
logger.warning(
f'When using seqlen metric, the difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your metric/hardware.'
)
self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[
CURRICULUM_LEARNING_SCHEDULE_CONFIG]
elif config[
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_LINEAR:
self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
elif config[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_LINEAR:
"""
The schedule_config is the same as CURRICULUM_LEARNING_SCHEDULE_FIXED_ROOT but without the
root_degree.
......@@ -100,15 +89,12 @@ class CurriculumScheduler(object):
f"Curriculum learning with fixed_linear schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_TOTAL_STEP}'"
assert CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP in config[CURRICULUM_LEARNING_SCHEDULE_CONFIG], \
f"Curriculum learning with fixed_linear schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP}'"
if config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][
CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP] % 8 != 0:
if config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP] % 8 != 0:
logger.warning(
f'When using seqlen metric, the difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your metric/hardware.'
)
self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[
CURRICULUM_LEARNING_SCHEDULE_CONFIG]
elif config[
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_CUSTOM:
self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
elif config[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_CUSTOM:
"""
Fully customized schedule. The user needs to provide a custom schedule
function by using the set_custom_curriculum_learning_schedule API.
......@@ -145,38 +131,28 @@ class CurriculumScheduler(object):
s_state = self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
if root_degree is None:
root_degree = s_state[CURRICULUM_LEARNING_SCHEDULE_ROOT_DEGREE]
next_difficulty = (float(global_steps) /
s_state[CURRICULUM_LEARNING_SCHEDULE_TOTAL_STEP])**(
1.0 / root_degree)
next_difficulty = math.floor(next_difficulty *
(self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY] -
self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY]) +
self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY])
next_difficulty -= (next_difficulty %
s_state[CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP])
next_difficulty = min(next_difficulty,
self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY])
next_difficulty = (float(global_steps) / s_state[CURRICULUM_LEARNING_SCHEDULE_TOTAL_STEP])**(1.0 / root_degree)
next_difficulty = math.floor(
next_difficulty *
(self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY] - self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY]) +
self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY])
next_difficulty -= (next_difficulty % s_state[CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP])
next_difficulty = min(next_difficulty, self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY])
return next_difficulty
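# Worked example of the fixed_root formula above (assumed numbers, for
# illustration only): with min_difficulty=8, max_difficulty=64,
# total_curriculum_step=10000, difficulty_step=8 and root_degree=2, at
# global_steps=2500:
#   (2500 / 10000) ** (1 / 2)   = 0.5
#   floor(0.5 * (64 - 8) + 8)   = 36
#   36 - (36 % 8)               = 32
#   min(32, 64)                 = 32  -> difficulty is 32 at step 2500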
def get_difficulty(self, global_steps):
if self.state[
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_DISCRETE:
if self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_DISCRETE:
return self.__fixed_discrete_get_difficulty(global_steps)
elif self.state[
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_LINEAR:
elif self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_LINEAR:
return self.__fixed_root_get_difficulty(global_steps, 1)
elif self.state[
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_ROOT:
elif self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_ROOT:
return self.__fixed_root_get_difficulty(global_steps)
elif self.state[
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_CUSTOM:
elif self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_CUSTOM:
return self.custom_get_difficulty(global_steps)
else:
raise RuntimeError('Unsupported curriculum schedule type')
def update_difficulty(self, global_steps):
if self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] < self.state[
CURRICULUM_LEARNING_MAX_DIFFICULTY]:
self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] = self.get_difficulty(
global_steps)
if self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] < self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY]:
self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] = self.get_difficulty(global_steps)
return self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY]
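# A minimal usage sketch for the CurriculumScheduler above. The numeric values
# are assumptions chosen for illustration; the keys are the constants imported
# from .constants at the top of this file.
example_config = {
    CURRICULUM_LEARNING_MIN_DIFFICULTY: 8,
    CURRICULUM_LEARNING_MAX_DIFFICULTY: 1024,
    CURRICULUM_LEARNING_SCHEDULE_TYPE: CURRICULUM_LEARNING_SCHEDULE_FIXED_LINEAR,
    CURRICULUM_LEARNING_SCHEDULE_CONFIG: {
        CURRICULUM_LEARNING_SCHEDULE_TOTAL_STEP: 10000,
        CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP: 8,
    },
}
scheduler = CurriculumScheduler(example_config)
for step in (100, 2500, 5000, 10000):
    # update_difficulty() grows the difficulty linearly and caps it at max_difficulty.
    print(step, scheduler.update_difficulty(step))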
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.utils import logger
from torch import Tensor
......@@ -14,6 +15,7 @@ class RandomLayerTokenDrop(Module):
"""
A layer wrapper for random LTD
"""
def __init__(self, layer: Module):
super(RandomLayerTokenDrop, self).__init__()
self.random_ltd_layer = layer
......@@ -52,9 +54,7 @@ class RandomLayerTokenDrop(Module):
elif self.model_type == 'decoder':
self.index_generator = gpt_sample_tokens
else:
logger.warning(
"************For now, we only support encoder-only or decoder-only models************"
)
logger.warning("************For now, we only support encoder-only or decoder-only models************")
raise NotImplementedError
def get_bsh(self, hidden_stats):
......@@ -78,40 +78,36 @@ class RandomLayerTokenDrop(Module):
self.curr_micro_batch, \
self.random_ltd_num_layer, \
hidden_states.device, mask)
self.random_ltd_scheduler.state[
RANDOM_LTD_SAMPLE_INDEX] = sampled_indices
self.random_ltd_scheduler.state[
RANDOM_LTD_ATTENTION_MASK] = part_attention_mask
self.random_ltd_scheduler.state[RANDOM_LTD_SAMPLE_INDEX] = sampled_indices
self.random_ltd_scheduler.state[RANDOM_LTD_ATTENTION_MASK] = part_attention_mask
else:
sampled_indices = self.random_ltd_scheduler.state[
RANDOM_LTD_SAMPLE_INDEX]
part_attention_mask = self.random_ltd_scheduler.state[
RANDOM_LTD_ATTENTION_MASK]
sampled_indices = self.random_ltd_scheduler.state[RANDOM_LTD_SAMPLE_INDEX]
part_attention_mask = self.random_ltd_scheduler.state[RANDOM_LTD_ATTENTION_MASK]
hidden_states, part_hidden_states = GatherTokens.apply(hidden_states, sampled_indices[self.random_ltd_layer_id,:,:], self.batch_first)
hidden_states, part_hidden_states = GatherTokens.apply(hidden_states,
sampled_indices[self.random_ltd_layer_id, :, :],
self.batch_first)
if self.mask_name is not None:
if self.model_type == 'encoder':
kwargs[self.mask_name] = part_attention_mask[
self.random_ltd_layer_id]
kwargs[self.mask_name] = part_attention_mask[self.random_ltd_layer_id]
else:
kwargs[self.mask_name] = part_attention_mask
outputs = self.random_ltd_layer(part_hidden_states, **kwargs)
if isinstance(outputs, tuple):
hidden_states = ScatterTokens.apply(hidden_states, outputs[0], sampled_indices[self.random_ltd_layer_id,:,:], self.batch_first)
hidden_states = ScatterTokens.apply(hidden_states, outputs[0],
sampled_indices[self.random_ltd_layer_id, :, :], self.batch_first)
my_list = list(outputs)
my_list[0] = hidden_states
return tuple(my_list)
elif isinstance(outputs, Tensor):
hidden_states = ScatterTokens.apply(hidden_states, outputs, sampled_indices[self.random_ltd_layer_id,:,:], self.batch_first)
hidden_states = ScatterTokens.apply(hidden_states, outputs,
sampled_indices[self.random_ltd_layer_id, :, :], self.batch_first)
return hidden_states
else:
logger.warning(
"************For now, we only support tuple and tensor output. \
You need to adjust the output according to the layer in your model************"
)
logger.warning("************For now, we only support tuple and tensor output. \
You need to adjust the output according to the layer in your model************")
raise NotImplementedError
else:
return self.random_ltd_layer(hidden_states, **kwargs)
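# A self-contained sketch of the idea implemented by RandomLayerTokenDrop above
# (illustrative only, not the DeepSpeed API): keep a random, sorted subset of
# token positions for one layer, run the layer on that subset, then scatter the
# layer's outputs back into the full hidden-state tensor.
import torch


def random_token_drop(layer, hidden_states, reserved_length):
    # hidden_states: (batch, seq, hidden); `layer` is any module mapping
    # (batch, n, hidden) -> (batch, n, hidden)
    batch, seq, hidden = hidden_states.shape
    idx = torch.stack([
        torch.randperm(seq, device=hidden_states.device)[:reserved_length].sort()[0] for _ in range(batch)
    ])  # (batch, reserved_length) sampled positions per sequence
    gather_idx = idx.unsqueeze(-1).expand(-1, -1, hidden)
    part = torch.gather(hidden_states, 1, gather_idx)  # the layer only sees the kept tokens
    part = layer(part)
    out = hidden_states.clone()
    out.scatter_(1, gather_idx, part)  # dropped positions keep their input values
    return out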
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .basic_layer import RandomLayerTokenDrop
from collections import OrderedDict
......
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import math
......@@ -12,6 +13,7 @@ from ..constants import *
class BaseScheduler(object):
def __init__(self):
self.state = {}
......@@ -19,12 +21,9 @@ class BaseScheduler(object):
s_state = self.state[RANDOM_LTD_SCHEDULE_CONFIG]
if root_degree is None:
root_degree = s_state['root_degree']
next_seq = (float(global_steps) /
s_state[RANDOM_LTD_REQUIRE_STEP])**(1.0 / root_degree)
next_seq = math.floor(
next_seq *
(self.state[RANDOM_LTD_MAX_VALUE] - self.state[RANDOM_LTD_MIN_VALUE]) +
self.state[RANDOM_LTD_MIN_VALUE])
next_seq = (float(global_steps) / s_state[RANDOM_LTD_REQUIRE_STEP])**(1.0 / root_degree)
next_seq = math.floor(next_seq * (self.state[RANDOM_LTD_MAX_VALUE] - self.state[RANDOM_LTD_MIN_VALUE]) +
self.state[RANDOM_LTD_MIN_VALUE])
next_seq -= (next_seq % s_state[RANDOM_LTD_INCREASE_STEP])
next_seq = min(next_seq, self.state[RANDOM_LTD_MAX_VALUE])
return next_seq
......@@ -37,6 +36,7 @@ class BaseScheduler(object):
class RandomLTDScheduler(BaseScheduler):
def __init__(self, config):
super().__init__()
self.model_layer_num = config[RANDOM_LTD_TOTAL_LAYER_NUM]
......@@ -61,12 +61,9 @@ class RandomLTDScheduler(BaseScheduler):
if self.config_schedule is not None:
self.state[RANDOM_LTD_MIN_VALUE] = self.config_schedule[RANDOM_LTD_MIN_VALUE]
self.state[RANDOM_LTD_MAX_VALUE] = self.config_schedule[RANDOM_LTD_MAX_VALUE]
self.state[RANDOM_LTD_CURRENT_VALUE] = self.config_schedule[
RANDOM_LTD_MIN_VALUE]
self.state[RANDOM_LTD_SCHEDULE_CONFIG] = self.config_schedule[
RANDOM_LTD_SCHEDULE_CONFIG]
self.state[RANDOM_LTD_SCHEDULER_TYPE] = self.config_schedule[
RANDOM_LTD_SCHEDULER_TYPE]
self.state[RANDOM_LTD_CURRENT_VALUE] = self.config_schedule[RANDOM_LTD_MIN_VALUE]
self.state[RANDOM_LTD_SCHEDULE_CONFIG] = self.config_schedule[RANDOM_LTD_SCHEDULE_CONFIG]
self.state[RANDOM_LTD_SCHEDULER_TYPE] = self.config_schedule[RANDOM_LTD_SCHEDULER_TYPE]
self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = 0
self.state[RANDOM_LTD_CURR_STEP] = -1
......@@ -95,8 +92,7 @@ class RandomLTDScheduler(BaseScheduler):
def state_dict(self):
return {
RANDOM_LTD_CONSUMED_LAYER_TOKENS:
self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS],
RANDOM_LTD_CONSUMED_LAYER_TOKENS: self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS],
RANDOM_LTD_CURR_STEP: self.state[RANDOM_LTD_CURR_STEP],
RANDOM_LTD_CURRENT_VALUE: self.state[RANDOM_LTD_CURRENT_VALUE],
RANDOM_LTD_MIN_VALUE: self.state[RANDOM_LTD_MIN_VALUE],
......@@ -104,8 +100,7 @@ class RandomLTDScheduler(BaseScheduler):
}
def load_state_dict(self, state_dict):
self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = state_dict[
RANDOM_LTD_CONSUMED_LAYER_TOKENS]
self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = state_dict[RANDOM_LTD_CONSUMED_LAYER_TOKENS]
self.state[RANDOM_LTD_CURR_STEP] = state_dict[RANDOM_LTD_CURR_STEP]
self.state[RANDOM_LTD_CURRENT_VALUE] = state_dict[RANDOM_LTD_CURRENT_VALUE]
self.state[RANDOM_LTD_MIN_VALUE] = state_dict[RANDOM_LTD_MIN_VALUE]
......
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
......@@ -10,8 +11,7 @@ def bsh_decoder_gather(reserved_length, hidden_states, mask):
rand_list = []
part_hidden_states = [] # batch, seq, hidden ## different from megatron
for k in range(hidden_states.size(0)):
B_tmp = torch.randperm(hidden_states.size(1),
device=hidden_states.device)[:reserved_length]
B_tmp = torch.randperm(hidden_states.size(1), device=hidden_states.device)[:reserved_length]
B = B_tmp.sort()[0]
rand_list.append(B)
part_hidden_states.append(hidden_states[k:k + 1, B, :])
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
from collections import defaultdict
import csv
......@@ -16,6 +18,7 @@ from .utils import split_dataset, split_index, create_mmap_dataset_builder, clos
class DataAnalyzer(object):
def __init__(self,
dataset,
num_workers=1,
......@@ -53,25 +56,19 @@ class DataAnalyzer(object):
self.custom_map_finalize = custom_map_finalize
self.custom_reduce = custom_reduce
def init_metric_results(self,
thread_id,
metric_names,
metric_types,
metric_dtypes,
save_path,
worker_id):
def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id):
metric_results = []
for m_idx in range(len(metric_names)):
metric_name, metric_type, metric_dtype = metric_names[m_idx], \
metric_types[m_idx], metric_dtypes[m_idx]
assert metric_dtype not in [np.float64, np.double], "Currently floating point metric values are not supported. Please change your metric into integer values (and potentially multiply a larger coefficient to keep the precision)."
assert metric_dtype not in [
np.float64, np.double
], "Currently floating point metric values are not supported. Please change your metric into integer values (and potentially multiply a larger coefficient to keep the precision)."
metric_save_path = f"{save_path}/{metric_name}/worker{worker_id}_thread{thread_id}/"
os.makedirs(metric_save_path, exist_ok=True)
if metric_type == 'single_value_per_sample':
sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
sample_to_metric_builder = create_mmap_dataset_builder(
sample_to_metric_fname,
metric_dtype)
sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_dtype)
metric_to_sample_fname = f"{metric_save_path}/{metric_name}_metric_to_sample"
os.system(f"rm -rf {metric_to_sample_fname}*")
metric_to_sample_dict = defaultdict(list)
......@@ -84,34 +81,25 @@ class DataAnalyzer(object):
elif metric_type == 'accumulate_value_over_samples':
metric_value = None
metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
metric_results.append({
"metric_value": metric_value,
"metric_value_fname": metric_value_fname
})
metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname})
return metric_results
def update_metric_results(self,
data,
metric_types,
metric_functions,
metric_results):
def update_metric_results(self, data, metric_types, metric_functions, metric_results):
for m_idx in range(len(metric_types)):
metric_type, metric_function, metric_result = metric_types[m_idx], \
metric_functions[m_idx], metric_results[m_idx]
if metric_type == 'single_value_per_sample':
metric_values = metric_function(data)
for row in range(metric_values.size()[0]):
metric_result["sample_to_metric_builder"].add_item(
metric_values[row].reshape(-1))
metric_result["metric_to_sample_dict"][
metric_values[row].item()].append(data['index'][row][0].item())
metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1))
metric_result["metric_to_sample_dict"][metric_values[row].item()].append(
data['index'][row][0].item())
for m_value in metric_result["metric_to_sample_dict"]:
if len(metric_result["metric_to_sample_dict"][m_value]) > 100:
metric_fname = metric_result["metric_to_sample_fname"]
with open(f"{metric_fname}_{m_value}.csv", 'a') as f:
writer = csv.writer(f)
writer.writerows(
[metric_result["metric_to_sample_dict"][m_value]])
writer.writerows([metric_result["metric_to_sample_dict"][m_value]])
metric_result["metric_to_sample_dict"][m_value] = []
elif metric_type == 'accumulate_value_over_samples':
metric_values = metric_function(data)
......@@ -126,25 +114,20 @@ class DataAnalyzer(object):
metric_dtypes[m_idx], metric_results[m_idx]
if metric_type == 'single_value_per_sample':
metric_fname = metric_result["sample_to_metric_fname"]
close_mmap_dataset_builder(metric_result["sample_to_metric_builder"],
metric_fname)
close_mmap_dataset_builder(metric_result["sample_to_metric_builder"], metric_fname)
for m_value in metric_result["metric_to_sample_dict"]:
if len(metric_result["metric_to_sample_dict"][m_value]) > 0:
metric_fname = metric_result["metric_to_sample_fname"]
with open(f"{metric_fname}_{m_value}.csv", 'a') as f:
writer = csv.writer(f)
writer.writerows(
[metric_result["metric_to_sample_dict"][m_value]])
writer.writerows([metric_result["metric_to_sample_dict"][m_value]])
metric_result["metric_to_sample_dict"][m_value] = []
elif metric_type == 'accumulate_value_over_samples':
if metric_result["metric_value"] is not None:
metric_value_builder = create_mmap_dataset_builder(
metric_result["metric_value_fname"],
metric_dtype)
metric_value_builder.add_item(
metric_result["metric_value"].reshape(-1))
close_mmap_dataset_builder(metric_value_builder,
metric_result["metric_value_fname"])
metric_value_builder = create_mmap_dataset_builder(metric_result["metric_value_fname"],
metric_dtype)
metric_value_builder.add_item(metric_result["metric_value"].reshape(-1))
close_mmap_dataset_builder(metric_value_builder, metric_result["metric_value_fname"])
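# A hedged sketch of the kind of metric function this analyzer expects for the
# 'single_value_per_sample' metric type: one integer value per sample in the
# batch (floating point metric values are rejected above). The 'input_ids'
# field and the padding id 0 are assumptions for illustration.
def seqlen_metric(data):
    return (data['input_ids'] != 0).sum(dim=1).to(torch.long)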
def run_map_helper(self, thread_id):
start_idx, end_idx = self.thread_splits[thread_id][0], \
......@@ -152,15 +135,9 @@ class DataAnalyzer(object):
logger.info(f"worker {self.worker_id} thread {thread_id}: start working " \
f"on data subset {start_idx} to {end_idx}")
thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx)))
sampler = BatchSampler(SequentialSampler(thread_dataset),
batch_size=self.batch_size,
drop_last=False)
sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False)
if self.collate_fn is None:
iterator = iter(
DataLoader(thread_dataset,
batch_sampler=sampler,
num_workers=0,
pin_memory=False))
iterator = iter(DataLoader(thread_dataset, batch_sampler=sampler, num_workers=0, pin_memory=False))
else:
iterator = iter(
DataLoader(thread_dataset,
......@@ -169,19 +146,11 @@ class DataAnalyzer(object):
collate_fn=self.collate_fn,
pin_memory=False))
if self.custom_map_init is None:
metric_results = self.init_metric_results(thread_id,
self.metric_names,
self.metric_types,
self.metric_dtypes,
self.save_path,
self.worker_id)
metric_results = self.init_metric_results(thread_id, self.metric_names, self.metric_types,
self.metric_dtypes, self.save_path, self.worker_id)
else:
metric_results = self.custom_map_init(thread_id,
self.metric_names,
self.metric_types,
self.metric_dtypes,
self.save_path,
self.worker_id)
metric_results = self.custom_map_init(thread_id, self.metric_names, self.metric_types, self.metric_dtypes,
self.save_path, self.worker_id)
total_sample = len(thread_dataset)
processed_sample = 0
start = time.time()
......@@ -189,15 +158,9 @@ class DataAnalyzer(object):
try:
data = next(iterator)
if self.custom_map_update is None:
self.update_metric_results(data,
self.metric_types,
self.metric_functions,
metric_results)
self.update_metric_results(data, self.metric_types, self.metric_functions, metric_results)
else:
self.custom_map_update(data,
self.metric_types,
self.metric_functions,
metric_results)
self.custom_map_update(data, self.metric_types, self.metric_functions, metric_results)
processed_sample += self.batch_size
duration = (time.time() - start) / 3600.0
remain_duration = duration * total_sample / processed_sample - duration
......@@ -206,22 +169,17 @@ class DataAnalyzer(object):
f"out of {total_sample} processed in {duration:.2f} hr, " \
f"estimated to finish in {remain_duration:.2f} hr")
except StopIteration:
logger.info(
f"worker {self.worker_id} thread {thread_id}: reach end of file")
logger.info(f"worker {self.worker_id} thread {thread_id}: reach end of file")
break
if self.custom_map_finalize is None:
self.finalize_metric_results(self.metric_types,
self.metric_dtypes,
metric_results)
self.finalize_metric_results(self.metric_types, self.metric_dtypes, metric_results)
else:
self.custom_map_finalize(self.metric_types,
self.metric_dtypes,
metric_results)
self.custom_map_finalize(self.metric_types, self.metric_dtypes, metric_results)
logger.info(f"worker {self.worker_id} thread {thread_id}: finished")
def run_map(self):
self.worker_splits, self.thread_splits = split_dataset(self.dataset,
self.num_workers, self.worker_id, self.num_threads)
self.worker_splits, self.thread_splits = split_dataset(self.dataset, self.num_workers, self.worker_id,
self.num_threads)
if len(self.specific_threads) > 0:
threads_to_run = self.specific_threads
else:
......@@ -238,81 +196,50 @@ class DataAnalyzer(object):
assert self.num_threads == 1
self.run_map_helper(0)
def get_metric_value_percentiles(self,
metric_name,
num_sample_per_value,
total_num_samples):
def get_metric_value_percentiles(self, metric_name, num_sample_per_value, total_num_samples):
logger.info(f"Checking the value percentiles of metric {metric_name}...")
processed_samples = 0
current_percentile = 5
for key in sorted(num_sample_per_value.keys()):
processed_samples += num_sample_per_value[key]
if processed_samples >= total_num_samples * current_percentile / 100.0:
logger.info(
f"Metric {metric_name} {current_percentile}th percentile: {key}")
logger.info(f"Metric {metric_name} {current_percentile}th percentile: {key}")
current_percentile += 5
def merge_gather_map_stats(self,
num_workers,
num_threads,
num_threads_reduce,
t_idx_reduce,
metric_save_path,
metric_name,
return_dict):
def merge_gather_map_stats(self, num_workers, num_threads, num_threads_reduce, t_idx_reduce, metric_save_path,
metric_name, return_dict):
results = []
for w_idx in range(num_workers):
for t_idx in range(num_threads):
if (w_idx * num_threads + t_idx) % num_threads_reduce == t_idx_reduce:
w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric"
w_sample_to_metric = MMapIndexedDataset(w_sample_to_metric_fname,
skip_warmup=True)
w_sample_to_metric = MMapIndexedDataset(w_sample_to_metric_fname, skip_warmup=True)
unique_v = list(np.unique(w_sample_to_metric))
sample_to_metric_count = len(w_sample_to_metric)
logger.info(
f"Finished gathering map stats from worker {w_idx} thread {t_idx}."
)
logger.info(f"Finished gathering map stats from worker {w_idx} thread {t_idx}.")
results.append([unique_v, sample_to_metric_count])
return_dict[t_idx_reduce] = results
def merge_sample_to_metric(self,
t_idx_reduce,
metric_save_path,
metric_name,
metric_value_dtype,
def merge_sample_to_metric(self, t_idx_reduce, metric_save_path, metric_name, metric_value_dtype,
map_worker_thread):
sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}"
sample_to_metric_builder = create_mmap_dataset_builder(
sample_to_metric_fname,
metric_value_dtype)
sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_value_dtype)
for w_t in map_worker_thread:
w_metric_save_path = f"{metric_save_path}/worker{w_t[0]}_thread{w_t[1]}/"
w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric"
w_data = MMapIndexedDataset(w_sample_to_metric_fname, skip_warmup=True)
for row in range(len(w_data)):
sample_to_metric_builder.add_item(
torch.tensor(w_data[row].astype(np.int64),
dtype=torch.long))
logger.info(
f"Finished merge_sample_to_metric from worker {w_t[0]} thread {w_t[1]}.")
sample_to_metric_builder.add_item(torch.tensor(w_data[row].astype(np.int64), dtype=torch.long))
logger.info(f"Finished merge_sample_to_metric from worker {w_t[0]} thread {w_t[1]}.")
close_mmap_dataset_builder(sample_to_metric_builder, sample_to_metric_fname)
def merge_metric_to_sample(self,
t_idx_reduce,
metric_save_path,
metric_name,
sample_idx_dtype,
metric_value_dtype,
unique_metric_values,
num_workers,
num_threads):
def merge_metric_to_sample(self, t_idx_reduce, metric_save_path, metric_name, sample_idx_dtype, metric_value_dtype,
unique_metric_values, num_workers, num_threads):
index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}"
index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname,
sample_idx_dtype)
index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname, sample_idx_dtype)
index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}"
index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname,
metric_value_dtype)
index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname, metric_value_dtype)
for unique_v in unique_metric_values:
samples = []
for w_idx in range(num_workers):
......@@ -330,13 +257,7 @@ class DataAnalyzer(object):
close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname)
close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname)
def merge_map_results(self,
dataset,
metric_names,
metric_types,
save_path,
num_workers,
num_threads,
def merge_map_results(self, dataset, metric_names, metric_types, save_path, num_workers, num_threads,
num_threads_reduce):
total_num_samples = len(dataset)
sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1)
......@@ -385,9 +306,7 @@ class DataAnalyzer(object):
for w_idx in range(num_workers):
for t_idx in range(num_threads):
map_worker_thread.append([w_idx, t_idx])
thread_splits = split_index(0,
len(map_worker_thread),
num_threads_reduce)
thread_splits = split_index(0, len(map_worker_thread), num_threads_reduce)
p = []
for t_idx_reduce in range(num_threads_reduce):
start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1]
......@@ -405,24 +324,18 @@ class DataAnalyzer(object):
p[t_idx_reduce].join()
sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
sample_to_metric_builder = create_mmap_dataset_builder(
sample_to_metric_fname,
metric_value_dtype)
sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_value_dtype)
for t_idx_reduce in range(num_threads_reduce):
chunk_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}"
logger.info(f"Merging file {chunk_fname}")
sample_to_metric_builder.merge_file_(chunk_fname)
close_mmap_dataset_builder(sample_to_metric_builder,
sample_to_metric_fname)
sample_to_metric = MMapIndexedDataset(sample_to_metric_fname,
skip_warmup=True)
close_mmap_dataset_builder(sample_to_metric_builder, sample_to_metric_fname)
sample_to_metric = MMapIndexedDataset(sample_to_metric_fname, skip_warmup=True)
assert len(sample_to_metric) == total_num_samples
# metric_to_sample
unique_metric_values = list(sorted(unique_metric_values))
thread_splits = split_index(0,
len(unique_metric_values),
num_threads_reduce)
thread_splits = split_index(0, len(unique_metric_values), num_threads_reduce)
p = []
for t_idx_reduce in range(num_threads_reduce):
start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1]
......@@ -442,13 +355,9 @@ class DataAnalyzer(object):
for t_idx_reduce in range(num_threads_reduce):
p[t_idx_reduce].join()
index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample"
index_to_sample_builder = create_mmap_dataset_builder(
index_to_sample_fname,
sample_idx_dtype)
index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname, sample_idx_dtype)
index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric"
index_to_metric_builder = create_mmap_dataset_builder(
index_to_metric_fname,
metric_value_dtype)
index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname, metric_value_dtype)
for t_idx_reduce in range(num_threads_reduce):
chunk_is_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}"
logger.info(f"Merging file {chunk_is_fname}")
......@@ -456,43 +365,29 @@ class DataAnalyzer(object):
chunk_im_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}"
logger.info(f"Merging file {chunk_im_fname}")
index_to_metric_builder.merge_file_(chunk_im_fname)
close_mmap_dataset_builder(index_to_sample_builder,
index_to_sample_fname)
close_mmap_dataset_builder(index_to_metric_builder,
index_to_metric_fname)
close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname)
close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname)
num_sample_per_value = {}
index_to_sample = MMapIndexedDataset(index_to_sample_fname,
skip_warmup=True)
index_to_metric = MMapIndexedDataset(index_to_metric_fname,
skip_warmup=True)
index_to_sample = MMapIndexedDataset(index_to_sample_fname, skip_warmup=True)
index_to_metric = MMapIndexedDataset(index_to_metric_fname, skip_warmup=True)
index_to_sample_merged_fname = f"{metric_save_path}/{metric_name}_index_to_sample_percentile_merged"
index_to_sample_merged_builder = create_mmap_dataset_builder(
index_to_sample_merged_fname,
sample_idx_dtype)
index_to_sample_merged_builder = create_mmap_dataset_builder(index_to_sample_merged_fname,
sample_idx_dtype)
for v_idx in range(len(index_to_sample)):
if v_idx > 0:
assert index_to_metric[v_idx] > index_to_metric[v_idx - 1]
num_sample_per_value[index_to_metric[v_idx][0]] = len(
index_to_sample[v_idx])
num_sample_per_value[index_to_metric[v_idx][0]] = len(index_to_sample[v_idx])
assert sum(num_sample_per_value.values()) == total_num_samples
merge_step = len(index_to_sample) // 100
for v_idx in range(0, len(index_to_sample), merge_step):
merged_samples = np.copy(
np.concatenate(
index_to_sample[v_idx:min(len(index_to_sample),
(v_idx + merge_step))],
axis=None))
np.concatenate(index_to_sample[v_idx:min(len(index_to_sample), (v_idx + merge_step))],
axis=None))
index_to_sample_merged_builder.add_item(
torch.tensor(merged_samples.astype(np.int64),
dtype=torch.long))
logger.info(
f"Finished merging index_to_sample {v_idx} to {v_idx+merge_step}."
)
close_mmap_dataset_builder(index_to_sample_merged_builder,
index_to_sample_merged_fname)
self.get_metric_value_percentiles(metric_name,
num_sample_per_value,
total_num_samples)
torch.tensor(merged_samples.astype(np.int64), dtype=torch.long))
logger.info(f"Finished merging index_to_sample {v_idx} to {v_idx+merge_step}.")
close_mmap_dataset_builder(index_to_sample_merged_builder, index_to_sample_merged_fname)
self.get_metric_value_percentiles(metric_name, num_sample_per_value, total_num_samples)
elif metric_type == 'accumulate_value_over_samples':
metric_save_path = f"{save_path}/{metric_name}/"
metric_value = None
......@@ -500,8 +395,7 @@ class DataAnalyzer(object):
for t_idx in range(num_threads):
w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
w_metric_value_fname = f"{w_metric_save_path}/{metric_name}_metric_value"
w_metric_value = MMapIndexedDataset(w_metric_value_fname,
skip_warmup=True)
w_metric_value = MMapIndexedDataset(w_metric_value_fname, skip_warmup=True)
if metric_value is None:
metric_value = np.copy(w_metric_value[0])
else:
......@@ -510,28 +404,14 @@ class DataAnalyzer(object):
value_min = int(min(metric_value))
metric_value_dtype = find_fit_int_dtype(value_min, value_max)
metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
metric_value_builder = create_mmap_dataset_builder(
metric_value_fname,
metric_value_dtype)
metric_value_builder.add_item(
torch.tensor(metric_value.astype(np.int64),
dtype=torch.long))
metric_value_builder = create_mmap_dataset_builder(metric_value_fname, metric_value_dtype)
metric_value_builder.add_item(torch.tensor(metric_value.astype(np.int64), dtype=torch.long))
close_mmap_dataset_builder(metric_value_builder, metric_value_fname)
def run_reduce(self):
if self.custom_reduce is None:
self.merge_map_results(self.dataset,
self.metric_names,
self.metric_types,
self.save_path,
self.num_workers,
self.num_threads,
self.num_threads_reduce)
self.merge_map_results(self.dataset, self.metric_names, self.metric_types, self.save_path,
self.num_workers, self.num_threads, self.num_threads_reduce)
else:
self.custom_reduce(self.dataset,
self.metric_names,
self.metric_types,
self.save_path,
self.num_workers,
self.num_threads,
self.num_threads_reduce)
self.custom_reduce(self.dataset, self.metric_names, self.metric_types, self.save_path, self.num_workers,
self.num_threads, self.num_threads_reduce)
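# A hedged sketch of reading the reduce outputs back: after run_map() and
# run_reduce(), each metric's merged results live under "{save_path}/{metric_name}/"
# and can be opened with MMapIndexedDataset. The save_path and metric name
# below ("./analysis", "seqlen") are assumptions for illustration.
index_to_metric = MMapIndexedDataset("./analysis/seqlen/seqlen_index_to_metric", skip_warmup=True)
index_to_sample = MMapIndexedDataset("./analysis/seqlen/seqlen_index_to_sample", skip_warmup=True)
for v_idx in range(len(index_to_metric)):
    # one row per distinct metric value, listing the sample indices with that value
    print(index_to_metric[v_idx][0], len(index_to_sample[v_idx]))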
'''
Copyright 2022 The Microsoft DeepSpeed Team
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
coding=utf-8
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Part of this code was adopted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/data_samplers.py
'''
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import torch
import os
......@@ -31,6 +34,7 @@ from .utils import create_mmap_dataset_builder, close_mmap_dataset_builder, find
class DeepSpeedDataSampler(object):
def __init__(self,
data_efficiency_config,
one_epoch_total_samples,
......@@ -45,8 +49,8 @@ class DeepSpeedDataSampler(object):
self.data_efficiency_config = data_efficiency_config
self.one_epoch_total_samples = one_epoch_total_samples
self.index_dtype = find_fit_int_dtype(0, one_epoch_total_samples)
self.total_samples = one_epoch_total_samples * self.data_efficiency_config[
DATA_SAMPLING][DATA_SAMPLING_NUM_EPOCHS]
self.total_samples = one_epoch_total_samples * self.data_efficiency_config[DATA_SAMPLING][
DATA_SAMPLING_NUM_EPOCHS]
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_group = data_parallel_group
......@@ -57,13 +61,11 @@ class DeepSpeedDataSampler(object):
self.gradient_accumulation_steps
self.global_rank = global_rank
self.drop_last = drop_last
self.np_rng = np.random.default_rng(
self.data_efficiency_config[DATA_EFFICIENCY_SEED])
self.np_rng = np.random.default_rng(self.data_efficiency_config[DATA_EFFICIENCY_SEED])
self.state = {}
self.batch = []
self.consumed_samples = 0
if self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][
CURRICULUM_LEARNING_ENABLED]:
if self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED]:
self.curriculum_step = 0
self.current_difficulties = {}
self.data_cluster_paths = []
......@@ -77,33 +79,26 @@ class DeepSpeedDataSampler(object):
if self.global_rank == 0:
self.data_clusters = []
self.data_cluster_sizes = []
cluster_path = self.data_efficiency_config[DATA_SAMPLING][
CURRICULUM_LEARNING][CURRICULUM_LEARNING_CLUSTER_PATH]
cluster_path = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][
CURRICULUM_LEARNING_CLUSTER_PATH]
if not os.path.exists(cluster_path):
os.makedirs(cluster_path)
for metric in self.data_efficiency_config[DATA_SAMPLING][
CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS]:
for metric in self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS]:
self.curriculum_schedulers[metric] = CurriculumScheduler(
data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING]
[CURRICULUM_LEARNING_METRICS][metric])
self.difficulty_type[metric] = data_efficiency_config[DATA_SAMPLING][
CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS][metric][
CURRICULUM_LEARNING_DIFFICULTY_TYPE]
self.clustering_type[metric] = data_efficiency_config[DATA_SAMPLING][
CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS][metric][
CURRICULUM_LEARNING_CLUSTERING_TYPE]
data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS][metric])
self.difficulty_type[metric] = data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][
CURRICULUM_LEARNING_METRICS][metric][CURRICULUM_LEARNING_DIFFICULTY_TYPE]
self.clustering_type[metric] = data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][
CURRICULUM_LEARNING_METRICS][metric][CURRICULUM_LEARNING_CLUSTERING_TYPE]
if self.global_rank == 0:
if self.clustering_type[metric] != CURRICULUM_LEARNING_SINGLE_CLUSTER:
self.curriculum_index_to_sample[metric] = MMapIndexedDataset(
data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING]
[CURRICULUM_LEARNING_METRICS][metric]
[CURRICULUM_LEARNING_SAMPLE_PATH],
data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS]
[metric][CURRICULUM_LEARNING_SAMPLE_PATH],
skip_warmup=True)
if self.difficulty_type[
metric] == CURRICULUM_LEARNING_VALUE_BASED:
if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED:
self.curriculum_index_to_metric[metric] = MMapIndexedDataset(
data_efficiency_config[DATA_SAMPLING]
[CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS]
data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS]
[metric][CURRICULUM_LEARNING_METRIC_PATH],
skip_warmup=True)
......@@ -122,8 +117,7 @@ class DeepSpeedDataSampler(object):
def set_custom_curriculum_learning_schedule(self, schedule_func_dict):
for metric in self.curriculum_schedulers:
if metric in schedule_func_dict:
self.curriculum_schedulers[metric].set_custom_get_difficulty(
schedule_func_dict[metric])
self.curriculum_schedulers[metric].set_custom_get_difficulty(schedule_func_dict[metric])
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
......@@ -133,26 +127,19 @@ class DeepSpeedDataSampler(object):
def get_sample_based_on_metric_value(self, metric, value_start, value_end):
new_samples = None
for row in range(len(self.curriculum_index_to_sample[metric])):
if self.curriculum_index_to_metric[metric][
row] <= value_end and self.curriculum_index_to_metric[metric][
row] > value_start:
if self.curriculum_index_to_metric[metric][row] <= value_end and self.curriculum_index_to_metric[metric][
row] > value_start:
row_samples = np.copy(self.curriculum_index_to_sample[metric][row])
new_samples = row_samples if new_samples is None else np.concatenate(
(new_samples,
row_samples),
axis=None)
(new_samples, row_samples), axis=None)
return new_samples
def get_sample_based_on_metric_percentile(self,
metric,
percentile_start,
percentile_end):
def get_sample_based_on_metric_percentile(self, metric, percentile_start, percentile_end):
new_samples = None
if self.data_1epoch_size is None:
self.data_1epoch_size = sum(
len(x) for x in self.curriculum_index_to_sample[metric])
max_percentile = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][
CURRICULUM_LEARNING_METRICS][metric][CURRICULUM_LEARNING_MAX_DIFFICULTY]
self.data_1epoch_size = sum(len(x) for x in self.curriculum_index_to_sample[metric])
max_percentile = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS][
metric][CURRICULUM_LEARNING_MAX_DIFFICULTY]
sample_per_percentile = self.data_1epoch_size // max_percentile
start_count = sample_per_percentile * percentile_start
end_count = sample_per_percentile * percentile_end
......@@ -167,12 +154,9 @@ class DeepSpeedDataSampler(object):
row_end = row_size
else:
row_end = end_count - current_count
row_samples = np.copy(
self.curriculum_index_to_sample[metric][row][row_start:row_end])
row_samples = np.copy(self.curriculum_index_to_sample[metric][row][row_start:row_end])
new_samples = row_samples if new_samples is None else np.concatenate(
(new_samples,
row_samples),
axis=None)
(new_samples, row_samples), axis=None)
current_count += row_size
if current_count >= end_count:
break
......@@ -193,63 +177,42 @@ class DeepSpeedDataSampler(object):
need_clustering += 1
if need_clustering > 1:
for metric in self.curriculum_schedulers:
if self.clustering_type[
metric] == CURRICULUM_LEARNING_SINGLE_CLUSTER:
if self.clustering_type[metric] == CURRICULUM_LEARNING_SINGLE_CLUSTER:
metric_cluster = np.arange(start=0,
stop=self.one_epoch_total_samples,
step=1,
dtype=self.index_dtype)
else:
if self.difficulty_type[
metric] == CURRICULUM_LEARNING_VALUE_BASED:
metric_cluster = self.get_sample_based_on_metric_value(
metric,
float('-inf'),
self.current_difficulties[metric])
elif self.difficulty_type[
metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED:
metric_cluster = self.get_sample_based_on_metric_value(metric, float('-inf'),
self.current_difficulties[metric])
elif self.difficulty_type[metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
metric_cluster = self.get_sample_based_on_metric_percentile(
metric,
0,
self.current_difficulties[metric])
metric, 0, self.current_difficulties[metric])
new_cluster = metric_cluster if new_cluster is None else \
np.intersect1d(new_cluster, metric_cluster, assume_unique=True)
for cluster in self.data_clusters:
new_cluster = np.setdiff1d(new_cluster,
cluster[0],
assume_unique=True)
new_cluster = np.setdiff1d(new_cluster, cluster[0], assume_unique=True)
else:
if len(self.data_clusters) == 0:
new_cluster = np.arange(start=0,
stop=self.one_epoch_total_samples,
step=1,
dtype=self.index_dtype)
new_cluster = np.arange(start=0, stop=self.one_epoch_total_samples, step=1, dtype=self.index_dtype)
for metric in self.curriculum_schedulers:
if self.clustering_type[metric] != CURRICULUM_LEARNING_SINGLE_CLUSTER:
if self.difficulty_type[
metric] == CURRICULUM_LEARNING_VALUE_BASED:
new_cluster = self.get_sample_based_on_metric_value(
metric,
previous_difficulties[metric],
self.current_difficulties[metric])
elif self.difficulty_type[
metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED:
new_cluster = self.get_sample_based_on_metric_value(metric, previous_difficulties[metric],
self.current_difficulties[metric])
elif self.difficulty_type[metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
new_cluster = self.get_sample_based_on_metric_percentile(
metric,
previous_difficulties[metric],
self.current_difficulties[metric])
metric, previous_difficulties[metric], self.current_difficulties[metric])
if new_cluster is not None and len(new_cluster) > 0:
logger.info(
f"new data cluster (previous_difficulties {previous_difficulties}, current_difficulties {self.current_difficulties}) with size {len(new_cluster)} generated."
)
self.np_rng.shuffle(new_cluster)
cluster_builder = create_mmap_dataset_builder(cluster_path,
self.index_dtype)
cluster_builder = create_mmap_dataset_builder(cluster_path, self.index_dtype)
cluster_builder.add_item_numpy(new_cluster)
close_mmap_dataset_builder(cluster_builder, cluster_path)
self.data_clusters.append(
MMapIndexedDataset(cluster_path,
skip_warmup=True))
self.data_clusters.append(MMapIndexedDataset(cluster_path, skip_warmup=True))
self.data_cluster_sizes.append(len(self.data_clusters[-1][0]))
else:
logger.info(
......@@ -264,10 +227,7 @@ class DeepSpeedDataSampler(object):
num_clusters = len(self.data_clusters)
weight_sum = sum(self.data_cluster_sizes)
weights = [x / weight_sum for x in self.data_cluster_sizes]
samples = self.np_rng.choice(num_clusters,
self.global_batch_size,
replace=True,
p=weights)
samples = self.np_rng.choice(num_clusters, self.global_batch_size, replace=True, p=weights)
samples = np.bincount(samples, minlength=num_clusters)
return samples
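# A standalone illustration of the cluster-proportional draw above (assumed
# numbers, independent of the class): each global batch picks a cluster per
# sample with probability proportional to cluster size, then bincount turns the
# draws into per-cluster sample counts.
rng = np.random.default_rng(0)
cluster_sizes = [600, 300, 100]
weights = [s / sum(cluster_sizes) for s in cluster_sizes]
draws = rng.choice(len(cluster_sizes), 10, replace=True, p=weights)
per_cluster_counts = np.bincount(draws, minlength=len(cluster_sizes))  # sums to 10, roughly 6/3/1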
......@@ -285,8 +245,7 @@ class DeepSpeedDataSampler(object):
def get_sample_from_cluster(self, cidx, num_samples):
start_idx = self.data_cluster_current_position[cidx]
samples = list(
np.copy(self.data_clusters[cidx][0][start_idx:(start_idx + num_samples)]))
samples = list(np.copy(self.data_clusters[cidx][0][start_idx:(start_idx + num_samples)]))
self.data_cluster_current_position[cidx] += num_samples
if len(samples) < num_samples:
num_samples_remained = num_samples - len(samples)
......@@ -297,14 +256,12 @@ class DeepSpeedDataSampler(object):
return samples
def get_next_global_batch(self):
if self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][
CURRICULUM_LEARNING_ENABLED]:
if self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED]:
self.curriculum_step += 1
new_cluster = False
previous_difficulties = {}
for metric in self.curriculum_schedulers:
next_difficulty = self.curriculum_schedulers[metric].update_difficulty(
self.curriculum_step)
next_difficulty = self.curriculum_schedulers[metric].update_difficulty(self.curriculum_step)
if metric not in self.current_difficulties or \
next_difficulty != self.current_difficulties[metric]:
new_cluster = True
......@@ -313,8 +270,7 @@ class DeepSpeedDataSampler(object):
else:
if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED:
previous_difficulties[metric] = float('-inf')
elif self.difficulty_type[
metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
elif self.difficulty_type[metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
previous_difficulties[metric] = 0
self.current_difficulties[metric] = next_difficulty
if new_cluster:
......@@ -323,12 +279,9 @@ class DeepSpeedDataSampler(object):
samples_per_cluster = self.sample_from_clusters()
batch = []
for cidx in range(len(samples_per_cluster)):
batch += self.get_sample_from_cluster(cidx,
samples_per_cluster[cidx])
batch += self.get_sample_from_cluster(cidx, samples_per_cluster[cidx])
self.np_rng.shuffle(batch)
batch = torch.tensor(batch,
device=get_accelerator().current_device_name(),
dtype=torch.long).view(-1)
batch = torch.tensor(batch, device=get_accelerator().current_device_name(), dtype=torch.long).view(-1)
else:
batch = torch.empty(self.global_batch_size,
device=get_accelerator().current_device_name(),
......@@ -356,8 +309,7 @@ class DeepSpeedDataSampler(object):
CURRICULUM_LEARNING_STEP: self.curriculum_step,
CURRICULUM_LEARNING_CURRENT_DIFFICULTIES: self.current_difficulties,
CURRICULUM_LEARNING_DATA_CLUSTER_PATHS: self.data_cluster_paths,
CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION:
self.data_cluster_current_position,
CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION: self.data_cluster_current_position,
CURRICULUM_LEARNING_NP_RNG_STATE: np.random.get_state()
}
......@@ -367,11 +319,10 @@ class DeepSpeedDataSampler(object):
self.curriculum_step = state_dict[CURRICULUM_LEARNING_STEP]
self.current_difficulties = state_dict[CURRICULUM_LEARNING_CURRENT_DIFFICULTIES]
self.data_cluster_paths = state_dict[CURRICULUM_LEARNING_DATA_CLUSTER_PATHS]
self.data_cluster_current_position = state_dict[
CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION]
self.data_cluster_current_position = state_dict[CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION]
np.random.set_state(state_dict[CURRICULUM_LEARNING_NP_RNG_STATE])
cluster_root_path = self.data_efficiency_config[DATA_SAMPLING][
CURRICULUM_LEARNING][CURRICULUM_LEARNING_CLUSTER_PATH]
cluster_root_path = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][
CURRICULUM_LEARNING_CLUSTER_PATH]
# Backward compatibility: previously data_cluster_paths were stored as
# absolute paths. Now we changed it to just the file name so that even
# if user moved the cluster files, the checkpoint loading still works
......@@ -379,12 +330,9 @@ class DeepSpeedDataSampler(object):
# in deepspeed json config.
for idx in range(len(self.data_cluster_paths)):
if '/' in self.data_cluster_paths[idx]:
self.data_cluster_paths[idx] = self.data_cluster_paths[idx].split(
'/')[-1]
self.data_cluster_paths[idx] = self.data_cluster_paths[idx].split('/')[-1]
if self.global_rank == 0:
for cluster_fname in self.data_cluster_paths:
cluster_path = f"{cluster_root_path}/{cluster_fname}"
self.data_clusters.append(
MMapIndexedDataset(cluster_path,
skip_warmup=True))
self.data_clusters.append(MMapIndexedDataset(cluster_path, skip_warmup=True))
self.data_cluster_sizes.append(len(self.data_clusters[-1][0]))
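# A hedged sketch of the nested data-efficiency config this sampler reads,
# written with the same constants used above; the metric name "seqlen", the
# paths, and all numeric values are assumptions for illustration. With
# CURRICULUM_LEARNING_SINGLE_CLUSTER clustering no sample/metric index files
# are needed; other clustering types also require CURRICULUM_LEARNING_SAMPLE_PATH
# (and CURRICULUM_LEARNING_METRIC_PATH for value-based difficulty).
example_data_efficiency_config = {
    DATA_EFFICIENCY_SEED: 1234,
    DATA_SAMPLING: {
        DATA_SAMPLING_NUM_EPOCHS: 1,
        CURRICULUM_LEARNING: {
            CURRICULUM_LEARNING_ENABLED: True,
            CURRICULUM_LEARNING_CLUSTER_PATH: "./curriculum_clusters",
            CURRICULUM_LEARNING_METRICS: {
                "seqlen": {
                    CURRICULUM_LEARNING_DIFFICULTY_TYPE: CURRICULUM_LEARNING_VALUE_BASED,
                    CURRICULUM_LEARNING_CLUSTERING_TYPE: CURRICULUM_LEARNING_SINGLE_CLUSTER,
                    CURRICULUM_LEARNING_MIN_DIFFICULTY: 64,
                    CURRICULUM_LEARNING_MAX_DIFFICULTY: 512,
                    CURRICULUM_LEARNING_SCHEDULE_TYPE: CURRICULUM_LEARNING_SCHEDULE_FIXED_LINEAR,
                    CURRICULUM_LEARNING_SCHEDULE_CONFIG: {
                        CURRICULUM_LEARNING_SCHEDULE_TOTAL_STEP: 10000,
                        CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP: 8,
                    },
                },
            },
        },
    },
}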
'''
Copyright 2022 The Microsoft DeepSpeed Team
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Part of this code was adopted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
'''
"""
# Copyright (c) Facebook, Inc. and its affiliates.
#
......@@ -50,16 +53,13 @@ def infer_dataset_impl(path):
return None
else:
print(f"Dataset does not exist: {path}")
print(
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
)
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
def make_builder(out_file, impl, vocab_size=None):
if impl == 'mmap':
return MMapIndexedDatasetBuilder(out_file,
dtype=__best_fitting_dtype(vocab_size))
return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
else:
return IndexedDatasetBuilder(out_file)
......@@ -67,9 +67,7 @@ def make_builder(out_file, impl, vocab_size=None):
def make_dataset(path, impl, skip_warmup=False):
if not IndexedDataset.exists(path):
print(f"Dataset does not exist: {path}")
print(
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
)
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
if impl == 'infer':
impl = infer_dataset_impl(path)
......@@ -150,10 +148,8 @@ class IndexedDataset(torch.utils.data.Dataset):
def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
assert magic == self._HDR_MAGIC, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.')
version = f.read(8)
assert struct.unpack('<Q', version) == (1, )
code, self.element_size = struct.unpack('<QQ', f.read(16))
......@@ -212,8 +208,7 @@ class IndexedDataset(torch.utils.data.Dataset):
@staticmethod
def exists(path):
return (os.path.exists(index_file_path(path))
and os.path.exists(data_file_path(path)))
return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
@property
def supports_prefetch(self):
......@@ -221,6 +216,7 @@ class IndexedDataset(torch.utils.data.Dataset):
class IndexedCachedDataset(IndexedDataset):
def __init__(self, path):
super().__init__(path)
self.cache = None
......@@ -273,15 +269,7 @@ class IndexedCachedDataset(IndexedDataset):
class IndexedDatasetBuilder(object):
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float64: 4,
np.double: 8
}
element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, np.float64: 4, np.double: 8}
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
......@@ -379,12 +367,15 @@ def get_pointers_with_total(sizes, elemsize, dtype):
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00'
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, 'wb')
......@@ -430,10 +421,8 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def __init__(self, path, skip_warmup=False):
with open(path, 'rb') as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
assert self._HDR_MAGIC == magic_test, ('Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.')
version = struct.unpack('<Q', stream.read(8))
assert (1, ) == version
......@@ -452,10 +441,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print(" reading sizes...")
self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
print(" reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer,
dtype=np.int64,
......@@ -465,8 +451,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
self._doc_idx = np.frombuffer(self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
......@@ -514,9 +499,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
print(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
print(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
......@@ -532,10 +515,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
......@@ -545,10 +525,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
sents = np.split(np_array, offsets[:-1])
return sents
......@@ -562,10 +539,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr)
return np_array
@property
......@@ -591,8 +565,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
@staticmethod
def exists(path):
return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
@property
def dtype(self):
......@@ -600,6 +573,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, 'wb')
self._dtype = dtype
......@@ -626,9 +600,7 @@ class MMapIndexedDatasetBuilder(object):
assert index.dtype == self._dtype
total_len = len(index.sizes) + len(self._sizes)
print(f"    concat {another_file} size={len(index.sizes)} for a total size of {total_len}")
offset = len(self._sizes)
self._sizes.extend(index.sizes)
......
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import math
import numpy as np
......@@ -32,19 +33,14 @@ def find_fit_int_dtype(min_value, max_value):
def split_index(start_idx, end_idx, num_partitions):
partition_size = math.ceil((end_idx - start_idx) / num_partitions)
partitions = [[start_idx + x * partition_size,
min(end_idx, start_idx + (x + 1) * partition_size)] for x in range(num_partitions)]
return partitions
def split_dataset(dataset, num_workers, worker_id, num_threads):
worker_splits = split_index(0, len(dataset), num_workers)
thread_splits = split_index(worker_splits[worker_id][0], worker_splits[worker_id][1], num_threads)
return worker_splits, thread_splits
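# Worked example (an illustrative sketch using the functions above): with 10 samples and
# num_partitions=3, partition_size = math.ceil(10 / 3) = 4 and the final slice is clamped
# to the end of the range.
_example_partitions = split_index(0, 10, 3)
assert _example_partitions == [[0, 4], [4, 8], [8, 10]]
# split_dataset() then re-applies the same split: first across workers, then across the
# threads of the selected worker.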
......
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
......@@ -14,6 +15,7 @@ from deepspeed.runtime.constants import GRADIENT_ACCUMULATION_STEPS, \
class RepeatingLoader:
def __init__(self, loader):
"""Wraps an iterator to allow for infinite iteration. This is especially useful
for DataLoader types that we wish to automatically restart upon completion.
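# Hedged usage sketch (the toy dataset below is illustrative, not part of this file):
#
#   import torch
#   from torch.utils.data import DataLoader, TensorDataset
#
#   finite_loader = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=4)
#   repeating = RepeatingLoader(finite_loader)
#   for _ in range(5):            # more steps than the 2 batches the dataset holds
#       batch = next(repeating)   # silently restarts instead of raising StopIteration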
......@@ -37,6 +39,7 @@ class RepeatingLoader:
class DeepSpeedDataLoader(object):
def __init__(self,
dataset,
batch_size,
......@@ -55,30 +58,26 @@ class DeepSpeedDataLoader(object):
self.batch_size = batch_size
self.curriculum_learning_enabled = False
if CURRICULUM_LEARNING in deepspeed_dataloader_config:
self.curriculum_learning_enabled = deepspeed_dataloader_config[CURRICULUM_LEARNING]
if self.curriculum_learning_enabled:
data_sampler = DeepSpeedDataSampler(self.deepspeed_dataloader_config[DATA_EFFICIENCY],
len(dataset),
self.batch_size,
data_parallel_rank,
data_parallel_world_size,
self.deepspeed_dataloader_config[DATA_PARALLEL_GROUP],
self.deepspeed_dataloader_config[GRADIENT_ACCUMULATION_STEPS],
self.deepspeed_dataloader_config[GLOBAL_RANK],
drop_last=dataloader_drop_last)
device_count = get_accelerator().device_count()
num_local_io_workers = self.deepspeed_dataloader_config[DATA_SAMPLING_NUM_WORKERS]
else:
if local_rank >= 0:
if data_sampler is None:
data_sampler = DistributedSampler(dataset=dataset,
num_replicas=data_parallel_world_size,
rank=data_parallel_rank)
device_count = 1
else:
if data_sampler is None:
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.utils import log_dist
......@@ -7,6 +10,7 @@ import logging
class Eigenvalue(object):
def __init__(self,
verbose=False,
max_iter=100,
......@@ -77,8 +81,7 @@ class Eigenvalue(object):
]
else:
v = [
torch.randn(p.size(), device=device) for p in model_block.parameters()
if p.grad is not None and p.grad.grad_fn is not None
]
torch.random.set_rng_state(rng_state)
......@@ -100,24 +103,18 @@ class Eigenvalue(object):
# Disable eigenvalue if the model doesn't support second order gradients computation,
# e.g. when enabling DS transformer kernel.
if len(grads) == 0 or len(params) == 0:
log_dist(f'The model does NOT support eigenvalue computation.', ranks=[0], level=logging.WARNING)
return []
i = 0
eigenvalue_current, eigenvalue_previous = 1., 0.
while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs(
(eigenvalue_current - eigenvalue_previous) / eigenvalue_current) >=
self.tol): # test convergence criteria
eigenvalue_previous = eigenvalue_current
Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=True)
#Hv = [hv.float() for hv in Hv]
Hv = [self.nan_to_num(hv).float() for hv in Hv]
......@@ -131,9 +128,7 @@ class Eigenvalue(object):
block_eigenvalue.append(eigenvalue_current)
if self.verbose:
log_dist(f'block: {block}, power iteration: {i}, eigenvalue: {eigenvalue_current}', ranks=[0])
block_eigenvalue = self.post_process(block_eigenvalue)
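# What the loop above computes, in brief: power iteration on the (implicit) Hessian.
# Each pass forms the Hessian-vector product H v via torch.autograd.grad over the
# first-order grads with grad_outputs=v, so H is never materialized; the elided part of
# this hunk re-normalizes v and refreshes eigenvalue_current, and the loop stops once the
# relative change falls below self.tol, leaving an estimate of the dominant eigenvalue
# for each model block.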
......
"""
Copyright 2019 The Microsoft DeepSpeed Team
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import re
......@@ -9,6 +10,7 @@ import torch
import hashlib
from collections import defaultdict, OrderedDict, deque
from shutil import copyfile
import gc
from torch.nn.modules import Module
from torch.nn.parameter import Parameter
......@@ -19,6 +21,7 @@ from typing import Callable, Dict, Union, Iterable
import deepspeed
from deepspeed import comm as dist
from deepspeed.runtime.utils import see_memory_usage, DummyOptim
from .zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
......@@ -31,7 +34,7 @@ from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER
......@@ -53,7 +56,7 @@ from deepspeed.compression.constants import \
WEIGHT_QUANTIZE_ROUNDING, \
WEIGHT_QUANTIZE_VERBOSE, \
WEIGHT_QUANTIZE_KERNEL
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS
from deepspeed.runtime.sparse_tensor import SparseTensor
from deepspeed.runtime import lr_schedules
......@@ -79,7 +82,7 @@ from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayer
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from .pipe.module import PipelineModule
from .utils import get_ma_status
from ..ops.adam import FusedAdam
from ..moe.sharded_moe import TopKGate, MOELayer
from ..moe.layer import MoE
......@@ -92,10 +95,7 @@ from deepspeed.utils.logging import print_json_dist, print_configuration
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.inference.config import DtypeEnum
# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None
from deepspeed.runtime.config import DtypeEnum
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
......@@ -110,16 +110,13 @@ try:
except ImportError:
# Fail silently so we don't spam logs unnecessarily if user isn't using amp
APEX_INSTALLED = False
pass
def split_half_float_double_sparse(tensors):
device_type = get_accelerator().device_name()
supported_types = [
"torch.{}.HalfTensor".format(device_type),
"torch.{}.FloatTensor".format(device_type),
"torch.{}.DoubleTensor".format(device_type),
"torch.{}.BFloat16Tensor".format(device_type),
"torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type),
"torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type),
SparseTensor.type()
]
......@@ -148,6 +145,7 @@ STEP_GLOBAL_TIMER = 'step'
class EngineTimers(object):
r"""Wallclock timers for DeepSpeedEngine"""
def __init__(self, enable_micro_timers, enable_global_timers):
self.forward_timers = []
self.backward_timers = []
......@@ -164,10 +162,7 @@ class EngineTimers(object):
self.backward_reduce_timers += [BACKWARD_REDUCE_MICRO_TIMER]
self.step_timers += [STEP_MICRO_TIMER]
self.micro_timers += [
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER,
STEP_MICRO_TIMER
]
......@@ -178,16 +173,14 @@ class EngineTimers(object):
self.backward_reduce_timers += [BACKWARD_REDUCE_GLOBAL_TIMER]
self.step_timers += [STEP_GLOBAL_TIMER]
self.global_timers += [
FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER,
STEP_GLOBAL_TIMER
]
class DeepSpeedEngine(Module):
r"""DeepSpeed engine for training."""
def __init__(
self,
args,
......@@ -200,7 +193,7 @@ class DeepSpeedEngine(Module):
dist_init_required=None,
collate_fn=None,
config=None,
config_class=None,
dont_change_device=False,
):
super(DeepSpeedEngine, self).__init__()
......@@ -218,6 +211,7 @@ class DeepSpeedEngine(Module):
self.gradient_average = True
self.warn_unscaled_loss = True
self.config = config
self._config = config_class
self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None
self.enable_backward_allreduce = True
......@@ -236,8 +230,6 @@ class DeepSpeedEngine(Module):
self.checkpoint_engine = None
global dist
from deepspeed import comm as dist
self._is_gradient_accumulation_boundary = None
self.scale_wrt_gas = None
......@@ -247,38 +239,15 @@ class DeepSpeedEngine(Module):
# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
self.param_names = {param: name for name, param in model.named_parameters()}
# Set config using config_params for backwards compat
if self.config is None and config_params is not None:
self.config = config_params
from deepspeed.comm import supported_torch_version
# This supported_torch_version check is for torch1.2 compatibility only
if supported_torch_version:
dist.init_distributed(dist_backend=self.dist_backend,
dist_init_required=dist_init_required)
else:
if dist_init_required is None:
dist_init_required = not dist.is_initialized()
if dist_init_required is False:
assert (
dist.is_initialized() is True
), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
else:
if not dist.is_initialized():
dist.init_process_group(backend=self.dist_backend)
self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
see_memory_usage(f"DeepSpeed Engine: After args sanity test",
force=self.memory_breakdown())
see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
if mpu is not None:
if self.elasticity_enabled():
if not self.is_elastic_model_parallel_supported():
assert not self.elasticity_enabled(), ("Elasticity is not currently supported"
" with model parallelism.")
self._set_distributed_vars(args)
......@@ -309,8 +278,7 @@ class DeepSpeedEngine(Module):
monitor_memory=False,
)
log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}",
ranks=[0])
log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", ranks=[0])
if self.flops_profiler_enabled():
self.flops_profiler = FlopsProfiler(self.module, self)
......@@ -332,6 +300,10 @@ class DeepSpeedEngine(Module):
if model_parameters is None:
model_parameters = self.module.parameters()
# Convert model parameters from generator to list
if not isinstance(model_parameters, list):
model_parameters = list(model_parameters)
if has_optimizer:
self._configure_optimizer(optimizer, model_parameters)
self._configure_lr_scheduler(lr_scheduler)
......@@ -346,12 +318,9 @@ class DeepSpeedEngine(Module):
self.sparse_tensor_module_names = set()
# if self.sparse_gradients_enabled():
for name, module in self.module.named_modules():
if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)) and self.sparse_gradients_enabled():
self.sparse_tensor_module_names.add(name + ".weight")
logger.info("Will convert {} to sparse tensor during training".format(name))
self.save_non_zero_checkpoint = False
self.save_zero_checkpoint = False
......@@ -365,23 +334,19 @@ class DeepSpeedEngine(Module):
self.progressive_layer_drop = self._configure_progressive_layer_drop()
if self.curriculum_enabled_legacy():
self.curriculum_scheduler_legacy = self._configure_curriculum_scheduler_legacy()
if self.random_ltd_enabled():
random_ltd_config = self.random_ltd_config()
random_ltd_config[RANDOM_LTD_GLOBAL_BATCH_SIZE] = self.train_batch_size()
random_ltd_config[RANDOM_LTD_MICRO_BATCH_SIZE] = self.train_micro_batch_size_per_gpu()
self.random_ltd_scheduler = self._configure_random_ltd_scheduler(random_ltd_config)
# Engine timers
self.engine_timers = EngineTimers(enable_micro_timers=self.wall_clock_breakdown(),
enable_global_timers=self.wall_clock_breakdown()
or self.flops_profiler_enabled())
if self.global_rank == 0:
self._config.print("DeepSpeedEngine configuration")
......@@ -414,10 +379,8 @@ class DeepSpeedEngine(Module):
if p.requires_grad:
trainable_num_params += n
if self.global_rank == 0:
self.autotuning_model_info["num_params"] = num_params * self.mp_world_size
self.autotuning_model_info["trainable_num_params"] = trainable_num_params * self.mp_world_size
logger.info(f"model parameter = {num_params}")
......@@ -447,13 +410,10 @@ class DeepSpeedEngine(Module):
ValueError: if ``train_batch_size`` is not divisible by the
configured micro-batch size and data parallelism.
"""
if train_batch_size % (self.train_micro_batch_size_per_gpu() * self.dp_world_size) != 0:
#print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}')
raise ValueError(f'Train batch size must be divisible by micro-batch data parallelism')
new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() * self.dp_world_size)
# overwrite config
self._config.train_batch_size = train_batch_size
self._config.gradient_accumulation_steps = new_gas
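# Worked example (numbers are illustrative): with train_micro_batch_size_per_gpu=4 and
# dp_world_size=8, the relation is train_batch_size = 4 * 8 * gradient_accumulation_steps.
# Requesting train_batch_size=64 therefore sets new_gas = 64 // (4 * 8) = 2, while 60
# would fail the divisibility check above (60 % 32 != 0).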
......@@ -464,8 +424,7 @@ class DeepSpeedEngine(Module):
def set_custom_curriculum_learning_schedule(self, schedule_func_dict):
if self.training_dataloader is not None and self.curriculum_learning_enabled():
self.training_dataloader.data_sampler.set_custom_curriculum_learning_schedule(schedule_func_dict)
def get_global_grad_norm(self) -> float:
"""Return the 2-norm of all gradients. If there is model parallelism,
......@@ -492,8 +451,7 @@ class DeepSpeedEngine(Module):
elif name in dir(_module):
return getattr(_module, name)
else:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def checkpoint_tag_validation_enabled(self):
return self._config.checkpoint_tag_validation_enabled
......@@ -567,15 +525,13 @@ class DeepSpeedEngine(Module):
return self._config.data_efficiency_config[DATA_SAMPLING]
def curriculum_learning_enabled(self):
return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED]
def curriculum_learning_config(self):
return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING]
def random_ltd_enabled(self):
return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD][RANDOM_LTD_ENABLED]
def random_ltd_config(self):
return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD]
......@@ -583,26 +539,20 @@ class DeepSpeedEngine(Module):
def random_ltd_initialize(self):
assert self.random_ltd_enabled()
random_ltd_config = self.random_ltd_config()
random_ltd_queue = deque([x for x in sorted(random_ltd_config[RANDOM_LTD_LAYER_ID])])
count = 0
for name, layer in self.module.named_modules():
if isinstance(layer, RandomLayerTokenDrop):
if len(random_ltd_queue) != 0 and str(random_ltd_queue[0]) in name: ###[1,2,3]
layer.init_config(random_ltd_config, self.random_ltd_scheduler, count)
random_ltd_queue.popleft()
count += 1
if random_ltd_config[RANDOM_LTD_LAYER_NUM] != count:
raise ValueError(f'random_ltd_layer_num {random_ltd_config[RANDOM_LTD_LAYER_NUM]} must be \
equivalent to the len of random_ltd_layer_id {count}')
if random_ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]:
assert self.client_lr_scheduler is None
raise ValueError(f'not yet support')
#self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler)
......@@ -663,8 +613,7 @@ class DeepSpeedEngine(Module):
def autotuning_profile_model_info(self):
return self.autotuning_enabled(
) and self._config.autotuning_config.model_info and self._config.autotuning_config.model_info.get(
"profile",
False)
"profile", False)
def sparse_gradients_enabled(self):
return self._config.sparse_gradients_enabled
......@@ -676,8 +625,7 @@ class DeepSpeedEngine(Module):
return self._config.train_micro_batch_size_per_gpu
def optimizer_name(self):
return (self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name)
def optimizer_params(self):
return self._config.optimizer_params
......@@ -695,22 +643,15 @@ class DeepSpeedEngine(Module):
return (
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_GROUPS],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_CHANGE_RATIO],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_TYPE],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ROUNDING],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_VERBOSE],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_KERNEL],
)
def zero_optimization(self):
......@@ -719,6 +660,9 @@ class DeepSpeedEngine(Module):
def zero_allow_untested_optimizer(self):
return self._config.zero_allow_untested_optimizer
def zero_force_ds_cpu_optimizer(self):
return self._config.zero_force_ds_cpu_optimizer
def zero_reduce_scatter(self):
return self._config.zero_config.reduce_scatter
......@@ -733,10 +677,7 @@ class DeepSpeedEngine(Module):
def zero_use_cpu_optimizer(self):
if self._config.zero_config.offload_optimizer is not None:
return self._config.zero_config.offload_optimizer.device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]
return False
def zero_cpu_offload(self):
......@@ -750,6 +691,9 @@ class DeepSpeedEngine(Module):
def zero_optimization_stage(self):
return self._config.zero_optimization_stage
def mics_shard_size(self):
return self._config.mics_shard_size
def zero_reduce_bucket_size(self):
return self._config.zero_config.reduce_bucket_size
......@@ -833,9 +777,11 @@ class DeepSpeedEngine(Module):
res = self._config.communication_data_type
if res is not None:
return res
if self.fp16_enabled():
return torch.float16
if self.bfloat16_enabled():
return torch.bfloat16
return torch.float32
......@@ -897,14 +843,11 @@ class DeepSpeedEngine(Module):
# First check for scheduler in json configuration
lr_scheduler = self._scheduler_from_config(self.optimizer)
if lr_scheduler:
log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0])
self.lr_scheduler = lr_scheduler
else:
if isinstance(client_lr_scheduler, Callable):
log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0])
self.lr_scheduler = client_lr_scheduler(self.basic_optimizer)
else:
log_dist('DeepSpeed using client LR scheduler', ranks=[0])
......@@ -919,12 +862,9 @@ class DeepSpeedEngine(Module):
try:
from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \
NebulaCheckpointEngine
self.checkpoint_engine = NebulaCheckpointEngine(config_params=self._config.nebula_config)
except ImportError as err:
logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
self.checkpoint_engine = TorchCheckpointEngine()
dp_rank = self.global_rank
......@@ -936,8 +876,7 @@ class DeepSpeedEngine(Module):
# only the first data parallel process needs to store the model checkpoint
# if you want to use node local storage this must be done by rank 0 on each
# node
self.save_non_zero_checkpoint = (rank == 0) or self.zero_optimization_partition_weights()
if self.zero_optimization() or self.bfloat16_enabled():
param_rank = dist.get_rank(group=self.optimizer.dp_process_group)
......@@ -952,9 +891,8 @@ class DeepSpeedEngine(Module):
if hasattr(lr_schedules, scheduler_name):
scheduler = getattr(lr_schedules, scheduler_name)
else:
assert hasattr(torch.optim.lr_scheduler,
scheduler_name), f"DeepSpeed does not recognize LR scheduler {scheduler_name}"
scheduler = getattr(torch.optim.lr_scheduler, scheduler_name)
......@@ -965,9 +903,7 @@ class DeepSpeedEngine(Module):
return None
def _set_distributed_vars(self, args):
device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank
if device_rank >= 0:
get_accelerator().set_device(device_rank)
self.device = torch.device(get_accelerator().device_name(), device_rank)
......@@ -996,48 +932,23 @@ class DeepSpeedEngine(Module):
if hasattr(args, 'local_rank'):
args.local_rank = self.local_rank
if self.config is None:
self.config = (args.deepspeed_config
if hasattr(args,
"deepspeed_config") else None)
self._config = DeepSpeedConfig(self.config, mpu)
# Validate command line arguments
def _do_args_sanity_check(self, args):
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
logger.warning(
"************ --deepscale_config is deprecated, please use --deepspeed_config ************"
)
if hasattr(args, "deepspeed_config"):
assert (
args.deepspeed_config is None
), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
args.deepspeed_config = args.deepscale_config
assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \
"variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \
"different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed."
if hasattr(args, 'local_rank') and args.local_rank != None:
assert isinstance(args.local_rank,
int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}"
if args.local_rank >= 0:
env_local_rank = int(os.environ.get("LOCAL_RANK"))
assert (
env_local_rank == args.local_rank
), f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}."
if self.config is None:
assert (
hasattr(
args, "deepspeed_config") and args.deepspeed_config is not None
), "DeepSpeed requires --deepspeed_config to specify configuration file"
def _is_supported_optimizer(self, optimizer_name):
return (optimizer_name in DEEPSPEED_OPTIMIZERS or getattr(torch.optim, optimizer_name, None) is not None)
def _supported_optims(self):
FairseqOptimizer = None
......@@ -1062,18 +973,11 @@ class DeepSpeedEngine(Module):
if not self.client_optimizer:
if self.optimizer_name() is not None:
assert self._is_supported_optimizer(
self.optimizer_name()), "{} is not a supported DeepSpeed Optimizer".format(self.optimizer_name())
if (self.optimizer_name() == LAMB_OPTIMIZER or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER):
assert (self.dynamic_loss_scale()), "DeepSpeed {} optimizer requires dynamic loss scaling".format(
self.optimizer_name())
# Detect invalid combinations of client optimizer and client scheduler
if isinstance(self.client_lr_scheduler, _LRScheduler):
......@@ -1081,6 +985,7 @@ class DeepSpeedEngine(Module):
f'Client Optimizer (type = {type(self.client_optimizer)} is not instantiated but Client LR Scheduler is instantiated'
def _broadcast_model(self):
def is_replicated(p):
if hasattr(p, "ds_status") and p.ds_status is not ZeroParamStatus.AVAILABLE:
return False
......@@ -1095,20 +1000,15 @@ class DeepSpeedEngine(Module):
group=self.expert_data_parallel_group[p.group_name])
else:
if torch.is_tensor(p) and is_replicated(p):
dist.broadcast(p, groups._get_broadcast_src_rank(), group=self.data_parallel_group)
@staticmethod
def __check_params(model: Module, dtype: torch.dtype) -> None:
return
if not all(param.dtype == dtype for param in model.parameters()) and dist.get_rank() == 0:
raise ValueError(f"{dtype} is enabled but the following parameters have dtype that is "
f"not {dtype}: "
f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}")
def _set_client_model(self, model):
# register client model in _modules so that nn.module methods work correctly
......@@ -1122,14 +1022,12 @@ class DeepSpeedEngine(Module):
if self.fp16_enabled():
if self.zero_optimization_partition_weights() and any(
[hasattr(param, "ds_id") for param in self.module.parameters()]):
self.__check_params(self.module, torch.half)
self.module.half()
elif self.bfloat16_enabled():
if self.zero_optimization_partition_weights() and any(
hasattr(param, 'ds_id') for param in self.module.parameters()):
self.__check_params(self.module, torch.bfloat16)
self.module.bfloat16()
else:
......@@ -1183,8 +1081,7 @@ class DeepSpeedEngine(Module):
return [id(param) for param in group]
occurrence = sum([
ids_list(group['params']).count(param_id) if param_id in ids_list(group['params']) else 0
for group in optimizer.param_groups
])
assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behaviour."
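# Concretely (an illustrative reading of the check above): if the same parameter tensor is
# listed in two param_groups, its id() is counted once per group, occurrence becomes 2,
# and the assert fires before DeepSpeed wraps the optimizer.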
......@@ -1204,9 +1101,7 @@ class DeepSpeedEngine(Module):
), 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
if self.global_rank == 0:
logger.warning("**** You are using ZeRO with an untested optimizer, proceed with caution *****")
if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and self.zero_optimization_stage(
) == 1:
......@@ -1214,23 +1109,19 @@ class DeepSpeedEngine(Module):
if model_dtype != grad_accum_dtype:
raise NotImplementedError(
"Model data type and gradient accumulation data type must be equal to use ZeRO"
)
"Model data type and gradient accumulation data type must be equal to use ZeRO")
return ZERO_OPTIMIZATION
elif amp_enabled:
if model_dtype != grad_accum_dtype:
raise NotImplementedError(
"Model data type and gradient accumulation data type must be equal to use Amp"
)
"Model data type and gradient accumulation data type must be equal to use Amp")
if model_dtype == torch.bfloat16 or model_dtype == torch.float16:
raise NotImplementedError("Cannot enable both amp with (legacy) fp16 or bfloat16 mode")
try:
logger.info("Initializing Apex amp from: {}".format(amp.__path__))
except NameError:
# If apex/amp is available it will be imported above
raise RuntimeError("Unable to import apex/amp, please make sure it is installed")
return AMP
# data type checks
elif model_dtype == grad_accum_dtype:
......@@ -1244,8 +1135,7 @@ class DeepSpeedEngine(Module):
elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32:
return BFLOAT16
else:
raise NotImplementedError("unsupported mix of model dtype and gradient accumulation type")
return None
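# Sketch of the decision made by the visible branches above (details cut by this hunk are
# not restated): ZeRO requires model and gradient-accumulation dtypes to match and returns
# ZERO_OPTIMIZATION; Amp additionally rejects fp16/bf16 models and returns AMP; matching
# dtypes are handled in the elided branch; a bf16 model with fp32 accumulation maps to
# BFLOAT16; any other mix raises NotImplementedError.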
......@@ -1256,27 +1146,26 @@ class DeepSpeedEngine(Module):
client_optimizer.param_groups[:] = [
pg for pg in client_optimizer.param_groups if len(pg["params"]) != 0
]
log_dist("Removing param_group that has no 'params' in the client Optimizer", ranks=[0])
basic_optimizer = client_optimizer
log_dist('Using client Optimizer as basic optimizer', ranks=[0])
else:
basic_optimizer = client_optimizer(model_parameters)
log_dist('Using client callable to create basic optimizer', ranks=[0])
if self.zero_use_cpu_optimizer() and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam):
if self.zero_force_ds_cpu_optimizer():
msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. ZeRO-Offload and understand the performance impacts you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.'
raise ZeRORuntimeException(msg)
else:
basic_optimizer = self._configure_basic_optimizer(model_parameters)
log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0])
self._check_for_duplicates(basic_optimizer)
self.basic_optimizer = basic_optimizer
log_dist("DeepSpeed Basic Optimizer = {}".format(
basic_optimizer.__class__.__name__),
ranks=[0])
log_dist("DeepSpeed Basic Optimizer = {}".format(basic_optimizer.__class__.__name__), ranks=[0])
optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer)
......@@ -1285,9 +1174,7 @@ class DeepSpeedEngine(Module):
elif optimizer_wrapper == AMP:
amp_params = self.amp_params()
log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0])
model, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
self._set_client_model(model)
self._broadcast_model()
# TODO: maybe need to broadcast experts differently?
......@@ -1298,8 +1185,7 @@ class DeepSpeedEngine(Module):
else:
self.optimizer = basic_optimizer
log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()),
ranks=[0])
log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), ranks=[0])
self.compression_scheduler = self._configure_compression_scheduler()
self.quantizer = self._configure_quantization()
......@@ -1314,32 +1200,24 @@ class DeepSpeedEngine(Module):
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
)
if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT)
# Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set
effective_adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER or adam_w_mode
if torch_adam:
if not effective_adam_w_mode:
optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
else:
optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters)
else:
if self.zero_use_cpu_optimizer():
from deepspeed.ops.adam import DeepSpeedCPUAdam
optimizer = DeepSpeedCPUAdam(model_parameters,
**optimizer_parameters,
adamw_mode=effective_adam_w_mode)
else:
from deepspeed.ops.adam import FusedAdam
......@@ -1349,6 +1227,12 @@ class DeepSpeedEngine(Module):
adam_w_mode=effective_adam_w_mode,
)
elif self.optimizer_name() == ADAGRAD_OPTIMIZER:
if self.zero_use_cpu_optimizer():
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters)
else:
optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb
......@@ -1359,26 +1243,21 @@ class DeepSpeedEngine(Module):
optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled():
logger.warning(f"Currently the convergence of 1-bit Adam is only verified under FP16")
elif self.optimizer_name() == ZERO_ONE_ADAM_OPTIMIZER:
assert not self.zero_optimization(), "0/1 Adam is not compatible with ZeRO"
from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled():
logger.warning(f'Currently the convergence of 0/1 Adam is only verified under FP16')
elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER:
assert not self.zero_optimization(), "1bit-Lamb is not compatible with ZeRO"
from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb
optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled():
logger.warning(f"Currently the convergence of 1-bit Lamb is only verified under FP16")
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
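# For reference, a minimal ds_config "optimizer" section that exercises the Adam branch
# above (values are illustrative; "torch_adam" / "adam_w_mode" correspond to the
# TORCH_ADAM_PARAM and ADAM_W_MODE keys popped earlier in this function):
#
#   "optimizer": {
#       "type": "Adam",
#       "params": {
#           "lr": 1e-4,
#           "betas": [0.9, 0.999],
#           "eps": 1e-8,
#           "weight_decay": 0.01,
#           "torch_adam": false,
#           "adam_w_mode": true
#       }
#   }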
......@@ -1403,7 +1282,8 @@ class DeepSpeedEngine(Module):
use_quantizer_kernel,
) = self.quantize_training()
if quantize_enabled and not quantize_weight_in_forward:
assert self.fp16_enabled(), "MoQ (quantize in optimization step) weight quantization is only supported for FP16"
assert self.fp16_enabled(
), "MoQ (quantize in optimization step) weight quantization is only supported for FP16"
quantizer = None
if quantize_enabled and not quantize_weight_in_forward:
from deepspeed.runtime.quantize import Quantizer
......@@ -1447,9 +1327,7 @@ class DeepSpeedEngine(Module):
has_moe_layers=self.has_moe_layers,
)
else:
log_dist(f'Creating fp16 optimizer with static loss scale: {self.loss_scale()}', ranks=[0])
optimizer = FP16_Optimizer(
optimizer,
deepspeed=self,
......@@ -1460,8 +1338,7 @@ class DeepSpeedEngine(Module):
has_moe_layers=self.has_moe_layers,
)
else:
log_dist(f'Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0])
optimizer = FP16_UnfusedOptimizer(
optimizer,
deepspeed=self,
......@@ -1484,19 +1361,20 @@ class DeepSpeedEngine(Module):
log_dist('Creating BF16 optimizer', ranks=[0])
timers = self.timers if self.wall_clock_breakdown() else None
optimizer = BF16_Optimizer(optimizer,
self.param_names,
mpu=self.mpu,
clip_grad=clip_grad,
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.data_parallel_group,
timers=timers)
return optimizer
def _configure_zero_optimizer(self, optimizer):
zero_stage = self.zero_optimization_stage()
mics_shard_size = self.mics_shard_size()
model_dtype, grad_accum_dtype = self.get_data_types()
timers = self.timers if self.wall_clock_breakdown() else None
......@@ -1514,8 +1392,7 @@ class DeepSpeedEngine(Module):
round_robin_gradients = self.zero_round_robin_gradients()
assert not isinstance(optimizer, DummyOptim), "zero stage {} requires an optimizer".format(zero_stage)
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
# Overlap and contiguous grads are meaningless in stage 1 and are ignored
if zero_stage == ZeroStageEnum.optimizer_states:
overlap_comm = False
......@@ -1526,9 +1403,7 @@ class DeepSpeedEngine(Module):
if isinstance(self.module, PipelineModule):
if overlap_comm:
logger.warning("Pipeline parallelism does not support overlapped communication, will be disabled.")
overlap_comm = False
optimizer = DeepSpeedZeroOptimizer(
optimizer,
......@@ -1542,10 +1417,8 @@ class DeepSpeedEngine(Module):
reduce_bucket_size=self.zero_reduce_bucket_size(),
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.data_parallel_group,
expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None,
expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=overlap_comm,
cpu_offload=self.zero_cpu_offload(),
......@@ -1557,8 +1430,7 @@ class DeepSpeedEngine(Module):
partition_grads=zero_stage == ZeroStageEnum.gradients,
round_robin_gradients=round_robin_gradients,
has_moe_layers=self.has_moe_layers,
fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(),
communication_data_type=self.communication_data_type,
elastic_checkpoint=self.zero_elastic_checkpoint())
......@@ -1566,21 +1438,27 @@ class DeepSpeedEngine(Module):
assert not self.has_moe_layers, "MoE not supported with Stage 3"
if isinstance(optimizer, DummyOptim):
log_dist("Creating ZeRO Offload", ranks=[0])
optimizer = DeepSpeedZeRoOffload(self.module,
timers=timers,
ds_config=self.config,
overlap_comm=self.zero_overlap_comm(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
offload_param_config=self.zero_offload_param(),
mpu=self.mpu)
else:
log_dist(
f'Creating fp16 ZeRO stage {zero_stage} optimizer,'
f' MiCS is enabled {mics_shard_size>0},'
f' Hierarchical params gather {self._config.mics_hierarchial_params_gather}',
ranks=[0])
if mics_shard_size > 0:
return self._return_mics_optimizer(optimizer, timers)
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
optimizer = DeepSpeedZeroOptimizer_Stage3(
self.module,
......@@ -1616,6 +1494,37 @@ class DeepSpeedEngine(Module):
return optimizer
def _return_mics_optimizer(self, basic_optimizer, timers):
from deepspeed.runtime.zero.mics import MiCS_Optimizer
optimizer = MiCS_Optimizer(self.module,
basic_optimizer,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
sub_group_size=self.zero_sub_group_size(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
aio_config=self.aio_config(),
communication_data_type=self.communication_data_type)
return optimizer
def _configure_eigenvalue(self):
eigenvalue = Eigenvalue(
verbose=self.eigenvalue_verbose(),
......@@ -1644,9 +1553,7 @@ class DeepSpeedEngine(Module):
@staticmethod
def is_iterable_style_dataset(obj):
return isinstance(obj,
torch.utils.data.IterableDataset
) # hasattr(obj, "__iter__") should work as well
return isinstance(obj, torch.utils.data.IterableDataset) # hasattr(obj, "__iter__") should work as well
def dataloader_drop_last(self):
return self._config.dataloader_drop_last
......@@ -1669,8 +1576,7 @@ class DeepSpeedEngine(Module):
data_sampler=None,
collate_fn=None,
num_local_io_workers=None):
if not (self.is_map_style_dataset(dataset)
or self.is_iterable_style_dataset(dataset)):
if not (self.is_map_style_dataset(dataset) or self.is_iterable_style_dataset(dataset)):
raise ValueError("Training data must be a torch Dataset")
if batch_size is None:
......@@ -1702,33 +1608,26 @@ class DeepSpeedEngine(Module):
deepspeed_dataloader_config = {}
if self.curriculum_learning_enabled():
deepspeed_dataloader_config = {
CURRICULUM_LEARNING: self.curriculum_learning_enabled(),
DATA_EFFICIENCY: self.data_efficiency_config(),
DATA_PARALLEL_GROUP: self.data_parallel_group,
GRADIENT_ACCUMULATION_STEPS: self.gradient_accumulation_steps(),
GLOBAL_RANK: self.global_rank,
DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
}
return DeepSpeedDataLoader(dataset=dataset,
batch_size=batch_size,
pin_memory=pin_memory,
collate_fn=collate_fn,
local_rank=self.local_rank,
tput_timer=deepspeed_io_timer,
num_local_io_workers=num_local_io_workers,
data_sampler=data_sampler,
data_parallel_world_size=data_parallel_world_size,
data_parallel_rank=data_parallel_rank,
dataloader_drop_last=self.dataloader_drop_last(),
deepspeed_dataloader_config=deepspeed_dataloader_config)
def train(self, mode=True):
r""""""
......@@ -1755,9 +1654,7 @@ class DeepSpeedEngine(Module):
else:
scaled_loss = prescaled_loss
if self.warn_unscaled_loss:
logger.warning(f"DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}")
self.warn_unscaled_loss = False
return scaled_loss
......@@ -1775,9 +1672,8 @@ class DeepSpeedEngine(Module):
else:
see_memory_usage("Engine before forward", force=self.memory_breakdown())
flops_profiler_active = (self.flops_profiler_enabled()
and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)
# used to check quantization happens at step 0!
if self.global_steps == 0 and hasattr(self, "compression_scheduler"):
......@@ -1806,10 +1702,7 @@ class DeepSpeedEngine(Module):
if self.module.training and self.curriculum_enabled_legacy():
self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
if self.module.training and self.random_ltd_enabled():
self.random_ltd_scheduler.update_seq(self.global_steps)
......@@ -1819,7 +1712,6 @@ class DeepSpeedEngine(Module):
# we are in a forward pass.
for module in self.module.modules():
module._parameters._in_forward = True
pass
self._start_timers(self.engine_timers.forward_timers)
......@@ -1844,9 +1736,7 @@ class DeepSpeedEngine(Module):
if self.autotuning_profile_model_info():
activation_mem = get_ma_status() - ma
self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())
exit()
else:
see_memory_usage("Engine after forward", force=self.memory_breakdown())
......@@ -1897,27 +1787,21 @@ class DeepSpeedEngine(Module):
f'allreduce_gradients() is not valid when bfloat+pipeline_parallelism is enabled'
# Pass (PP) gas boundary flag to optimizer (required for zero)
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
# ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
# Communicate only at gradient accumulation boundaries
elif self.is_gradient_accumulation_boundary():
if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states and hasattr(
self.optimizer, 'reduce_gradients'):
self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
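# A gradient-accumulation boundary, illustrated (a sketch; the authoritative check is
# is_gradient_accumulation_boundary()): with gradient_accumulation_steps=4, micro-steps
# 1-3 only accumulate gradients locally and the reduction paths above run on every 4th
# micro-step, which keeps allreduce traffic proportional to optimizer steps rather than
# forward/backward passes (except for ZeRO stage >= 2, handled earlier).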
@instrument_w_nvtx
def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True):
r"""Execute backward pass on the loss
Arguments:
loss: Torch tensor on which to execute backward propagation
......@@ -1932,9 +1816,7 @@ class DeepSpeedEngine(Module):
scale_wrt_gas = self.scale_wrt_gas
if not allreduce_gradients:
logger.warning(f"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed")
# scale loss w.r.t. gradient accumulation if needed
if self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
......@@ -1959,16 +1841,13 @@ class DeepSpeedEngine(Module):
self._start_timers(self.engine_timers.backward_inner_timers)
if self.zero_optimization():
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
)
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
self.optimizer.backward(loss, retain_graph=retain_graph)
elif self.amp_enabled():
# AMP requires delaying unscale when inside gradient accumulation boundaries
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = not self.is_gradient_accumulation_boundary()
with amp.scale_loss(loss,
self.optimizer,
delay_unscale=delay_unscale) as scaled_loss:
with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
scaled_loss.backward(retain_graph=retain_graph)
elif self.fp16_enabled():
if self.eigenvalue_enabled():
......@@ -2051,22 +1930,17 @@ class DeepSpeedEngine(Module):
param.grad = None
def clip_fp32_gradients(self):
clip_grad_norm_(parameters=self.module.parameters(),
max_norm=self.gradient_clipping(),
mpu=self.mpu)
clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu)
def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
if self.gradient_clipping() > 0.0:
if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled()
or self.zero_optimization()):
if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()):
self.clip_fp32_gradients()
elif self.amp_enabled():
# AMP's recommended way of doing clipping
# https://nvidia.github.io/apex/advanced.html#gradient-clipping
master_params = amp.master_params(self.optimizer)
clip_grad_norm_(parameters=master_params,
max_norm=self.gradient_clipping(),
mpu=self.mpu)
clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)
self.optimizer.step()
if hasattr(self.optimizer, '_global_grad_norm'):
......@@ -2087,7 +1961,7 @@ class DeepSpeedEngine(Module):
# the behaviour that we want
if self.bfloat16_enabled():
# TODO: Temporary until bf16_optimizer and zero_optimizer are integrated
if self.zero_optimization():
if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"):
self.optimizer.zero_grad()
else:
pass
......@@ -2132,8 +2006,7 @@ class DeepSpeedEngine(Module):
# Check early because self.global_steps is incremented at some point here.
# TODO: Delay self.global_steps increment until very end of this function.
flops_profiler_active = self.flops_profiler_enabled(
) and self.global_steps == self.flops_profiler_profile_step(
) and self.global_rank == 0
) and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0
self._start_timers(self.engine_timers.step_timers)
......@@ -2148,20 +2021,16 @@ class DeepSpeedEngine(Module):
if self.is_gradient_accumulation_boundary():
self.gas_boundary_ctr += 1
if (self.eigenvalue_enabled() and
(self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
and self.quantizer.any_precision_switch()):
log_dist(f"computing eigenvalue...", ranks=[0])
self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(
self.module,
self.device,
self.optimizer.cur_scale)
self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device,
self.optimizer.cur_scale)
if self.progressive_layer_drop:
self.progressive_layer_drop.update_state(self.global_steps)
if (self.eigenvalue_enabled() and not self.gas_boundary_ctr %
self.eigenvalue_gas_boundary_resolution()
if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()
and self.quantizer.any_precision_switch()):
self._take_model_step(lr_kwargs, self.block_eigenvalue)
else:
......@@ -2169,8 +2038,7 @@ class DeepSpeedEngine(Module):
report_progress = self.global_rank == 0 if self.global_rank else True
self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(),
report_speed=report_progress)
self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), report_speed=report_progress)
self._stop_timers(self.engine_timers.step_timers)
......@@ -2178,9 +2046,7 @@ class DeepSpeedEngine(Module):
if self.monitor.enabled:
if self.is_gradient_accumulation_boundary():
if self.global_rank == 0:
self.summary_events = [(f"Train/Samples/lr",
self.get_lr()[0],
self.global_samples)]
self.summary_events = [(f"Train/Samples/lr", self.get_lr()[0], self.global_samples)]
if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"):
self.summary_events.append((
......@@ -2189,8 +2055,8 @@ class DeepSpeedEngine(Module):
self.global_samples,
))
if (self.eigenvalue_enabled() and not self.gas_boundary_ctr %
self.eigenvalue_gas_boundary_resolution()):
if (self.eigenvalue_enabled()
and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()):
ev_values = self.block_eigenvalue.values()
for i in range(len(ev_values)):
self.summary_events.append((
......@@ -2214,14 +2080,12 @@ class DeepSpeedEngine(Module):
)
self.flops_profiler.end_profile()
if self.autotuning_enabled() and self.global_steps == (
self.autotuning_end_profile_step() + 1):
if self.autotuning_enabled() and self.global_steps == (self.autotuning_end_profile_step() + 1):
self._autotuning_exit()
if self.wall_clock_breakdown():
# Log micro timing and reset
self.timers.log(names=self.engine_timers.micro_timers,
memory_breakdown=self.memory_breakdown())
self.timers.log(names=self.engine_timers.micro_timers, memory_breakdown=self.memory_breakdown())
if self.wall_clock_breakdown() or self.flops_profiler_enabled():
# Log global timing and reset
......@@ -2255,13 +2119,10 @@ class DeepSpeedEngine(Module):
FORWARD_GLOBAL_TIMER,
BACKWARD_GLOBAL_TIMER,
STEP_GLOBAL_TIMER,
],
reset=False)
titer = msg[FORWARD_GLOBAL_TIMER] + msg[BACKWARD_GLOBAL_TIMER] + msg[
STEP_GLOBAL_TIMER]
], reset=False)
titer = msg[FORWARD_GLOBAL_TIMER] + msg[BACKWARD_GLOBAL_TIMER] + msg[STEP_GLOBAL_TIMER]
msg["latency"] = titer
msg["FLOPS_per_gpu"] = self.flops * 1_000_000 * self.gradient_accumulation_steps(
) / titer
msg["FLOPS_per_gpu"] = self.flops * 1_000_000 * self.gradient_accumulation_steps() / titer
msg["throughput"] = self.train_batch_size() * 1_000_000 / \
msg["latency"]
print_json_dist(msg, [0], path=self.autotuning_metric_path())
......@@ -2335,8 +2196,7 @@ class DeepSpeedEngine(Module):
def _report_progress(self, step):
lr = self.get_lr()
mom = self.get_mom()
log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}",
ranks=[0])
log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0])
def allreduce_bucket(self, bucket, dp_group):
tensor = self.flatten(bucket)
......@@ -2352,10 +2212,8 @@ class DeepSpeedEngine(Module):
dist.all_reduce(tensor_to_allreduce, group=dp_group)
if self.gradient_average:
if self.gradient_predivide_factor() != dist.get_world_size(
group=dp_group):
tensor_to_allreduce.mul_(self.gradient_predivide_factor() /
dist.get_world_size(group=dp_group))
if self.gradient_predivide_factor() != dist.get_world_size(group=dp_group):
tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
else:
tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group))
dist.all_reduce(tensor_to_allreduce, group=dp_group)
......@@ -2397,9 +2255,7 @@ class DeepSpeedEngine(Module):
# rank is reducing the same size. In some cases it may make
# sense in the future to support the ability to average not
# w.r.t. world size but with a different value.
param.grad = torch.zeros(param.size(),
dtype=param.dtype,
device=param.device)
param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
grad_data = param.grad.data
if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
......@@ -2426,9 +2282,7 @@ class DeepSpeedEngine(Module):
if bucket_type == SparseTensor.type():
self.sparse_allreduce_no_retain(bucket, dp_group=dp_group)
else:
self.allreduce_no_retain(bucket,
dp_group=dp_group,
numel_per_bucket=elements_per_buffer)
self.allreduce_no_retain(bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)
def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
for ep_name, expert_grads_group in expert_grads.items():
......@@ -2436,15 +2290,12 @@ class DeepSpeedEngine(Module):
for i, bucket_tuple in enumerate(expert_split_buckets):
bucket_type, bucket = bucket_tuple
if bucket_type == SparseTensor.type():
self.sparse_allreduce_no_retain(
bucket,
groups._get_expert_data_parallel_group(ep_name))
self.sparse_allreduce_no_retain(bucket, groups._get_expert_data_parallel_group(ep_name))
else:
# Separate between diff groups
self.allreduce_no_retain(
bucket,
dp_group=groups._get_expert_data_parallel_group(ep_name),
numel_per_bucket=elements_per_buffer)
self.allreduce_no_retain(bucket,
dp_group=groups._get_expert_data_parallel_group(ep_name),
numel_per_bucket=elements_per_buffer)
def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
if grads is None:
......@@ -2487,8 +2338,7 @@ class DeepSpeedEngine(Module):
if self.postscale_gradients():
if self.gradient_average:
values.mul_(self.gradient_predivide_factor() /
dist.get_world_size(group=dp_group))
values.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
else:
values.mul_(1. / dist.get_world_size(group=dp_group))
......@@ -2509,36 +2359,25 @@ class DeepSpeedEngine(Module):
if value.dim() == 1:
if fill_size > 0:
value = torch.cat([value, value.new_empty(fill_size)])
tensor_list = [
value.new_empty(max_size)
for _ in range(dist.get_world_size(group=dp_group))
]
tensor_list = [value.new_empty(max_size) for _ in range(dist.get_world_size(group=dp_group))]
else:
if fill_size > 0:
value = torch.cat([value, value.new_empty(fill_size, value.size()[1])])
tensor_list = [
value.new_empty(max_size,
value.size()[1])
for _ in range(dist.get_world_size(group=dp_group))
value.size()[1]) for _ in range(dist.get_world_size(group=dp_group))
]
dist.all_gather(tensor_list, value, group=dp_group)
tensors = []
for dev_idx, t in enumerate(tensor_list):
size = all_sizes[dev_idx][0]
tensors.append(
t.index_select(0,
torch.arange(size,
dtype=torch.long,
device=self.device)))
tensors.append(t.index_select(0, torch.arange(size, dtype=torch.long, device=self.device)))
return tensors
def all_gather_scalar(self, value, dp_group):
tensor_list = [
value.new_zeros(value.size())
for _ in range(dist.get_world_size(group=dp_group))
]
tensor_list = [value.new_zeros(value.size()) for _ in range(dist.get_world_size(group=dp_group))]
dist.all_gather(tensor_list, value, group=dp_group)
return tensor_list
......@@ -2558,20 +2397,19 @@ class DeepSpeedEngine(Module):
num_experts=1,
checkpoint_engine=TorchCheckpointEngine()):
if old_moe_load:
expp_rank = groups._get_expert_data_parallel_rank(
groups._get_max_expert_size_name())
expp_rank = groups._get_expert_data_parallel_rank(groups._get_max_expert_size_name())
num_local_experts = max(
num_experts) // groups._get_expert_parallel_world_size(
groups._get_max_expert_size_name())
num_local_experts = max(num_experts) // groups._get_expert_parallel_world_size(
groups._get_max_expert_size_name())
for local_expert_id in range(num_local_experts):
global_expert_id = expp_rank * num_local_experts + local_expert_id
expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name(
checkpoint_path,
-1, # -1 means ignore layer_id
global_expert_id,
tag,
mpu),
expert_state_dict = checkpoint_engine.load(
DeepSpeedEngine._get_expert_ckpt_name(
checkpoint_path,
-1, # -1 means ignore layer_id
global_expert_id,
tag,
mpu),
map_location=torch.device('cpu'))
# Updating global -> local expert ids
......@@ -2592,41 +2430,45 @@ class DeepSpeedEngine(Module):
# loop all local_experts
for local_expert_id in range(num_local_experts):
global_expert_id = expp_rank * num_local_experts + local_expert_id
expert_state_dict = checkpoint_engine.load(
DeepSpeedEngine._get_expert_ckpt_name(
checkpoint_path,
moe_layer_id,
global_expert_id,
tag,
mpu),
map_location=torch.device('cpu'))
expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name(
checkpoint_path, moe_layer_id, global_expert_id, tag, mpu),
map_location=torch.device('cpu'))
# print(expert_state_dict.keys())
# Updating global -> local expert ids
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
for key in list(expert_state_dict.keys()):
local_key = key.replace(
f'{moe_str_prefix}{global_expert_id}',
f'{moe_str_prefix}{local_expert_id}')
local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',
f'{moe_str_prefix}{local_expert_id}')
expert_state_dict[local_key] = expert_state_dict.pop(key)
state_dict.update(expert_state_dict)
moe_layer_id += 1
def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None):
def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None):
module_state_dict = checkpoint['module']
if custom_load_fn:
custom_load_fn(src=state_dict, dst=self.module)
custom_load_fn(src=module_state_dict, dst=self.module)
else:
self.module.load_state_dict(state_dict, # TODO
strict=strict)
self.module.load_state_dict(
module_state_dict, # TODO
strict=strict)
if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None:
saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS]
for param in self.module.parameters():
if param.requires_grad:
continue
if param not in self.param_names:
raise ValueError(f"failed to find frozen {param} in named params")
name = self.param_names[param]
if hasattr(param, 'ds_id'):
param.ds_tensor.data.copy_(saved_frozen_params[name].data)
else:
param.data.copy_(saved_frozen_params[name].data)
def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode):
return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}'
def _get_rank_zero_ckpt_name(self,
checkpoints_path,
tag,
mp_rank,
dp_rank,
bf16_mode):
def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank, bf16_mode):
file_prefix = self._get_zero_ckpt_prefix(dp_rank, bf16_mode=bf16_mode)
zero_ckpt_name = os.path.join(
checkpoints_path,
......@@ -2639,11 +2481,7 @@ class DeepSpeedEngine(Module):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
pp_rank = dist.get_rank(group=self.optimizer.dp_process_group)
bf16_mode = self.bfloat16_enabled()
return self._get_rank_zero_ckpt_name(checkpoints_path,
tag,
mp_rank,
pp_rank,
bf16_mode)
return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank, bf16_mode)
def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
if mp_placeholder is not None:
......@@ -2653,8 +2491,7 @@ class DeepSpeedEngine(Module):
mp_rank_str = f"{mp_rank:02d}"
if self.zero_optimization_partition_weights():
filename = "zero_pp_rank_{}".format(
dist.get_rank(group=self.optimizer.dp_process_group))
filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
ckpt_name = os.path.join(
checkpoints_path,
str(tag),
......@@ -2670,10 +2507,8 @@ class DeepSpeedEngine(Module):
def _get_optimizer_ckpt_name(self, checkpoints_path, tag, expp_rank):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
ckpt_name = os.path.join(
checkpoints_path,
str(tag),
f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt')
ckpt_name = os.path.join(checkpoints_path, str(tag),
f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt')
return ckpt_name
@staticmethod
......@@ -2681,24 +2516,17 @@ class DeepSpeedEngine(Module):
mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()
if layer_id <= -1:
# Used to support old checkpoint loading
ckpt_name = os.path.join(
checkpoints_path,
'' if tag is None else str(tag),
f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag),
f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
else:
# Used to support new checkpoint loading
ckpt_name = os.path.join(
checkpoints_path,
'' if tag is None else str(tag),
f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt'
)
ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag),
f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
return ckpt_name
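# Illustrative values (not from this diff): with checkpoints_path='ckpts',
# tag='global_step100', mp_rank=0 and expert_id=5 this resolves to
#   'ckpts/global_step100/expert_5_mp_rank_00_model_states.pt'           (layer_id <= -1)
#   'ckpts/global_step100/layer_2_expert_5_mp_rank_00_model_states.pt'   (layer_id == 2)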
def _get_all_ckpt_names(self, checkpoints_path, tag):
# It is required that (checkpoints_path, tag) are consistent among all ranks.
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
tag,
mp_placeholder="*")
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
import glob
ckpt_files = glob.glob(ckpt_file_pattern)
......@@ -2738,17 +2566,14 @@ class DeepSpeedEngine(Module):
"""
if tag is None:
latest_tag = "latest_universal" if self.load_universal_checkpoint(
) else "latest"
latest_tag = "latest_universal" if self.load_universal_checkpoint() else "latest"
latest_path = os.path.join(load_dir, latest_tag)
if os.path.isfile(latest_path):
with open(latest_path, "r") as fd:
tag = fd.read().strip()
else:
if self.load_universal_checkpoint():
raise ValueError(
f'Invalid for universal checkpoint: {latest_path} does not exist'
)
raise ValueError(f'Invalid for universal checkpoint: {latest_path} does not exist')
else:
logger.warning(
f"Unable to find latest file at {latest_path}, if trying to load latest "
......@@ -2770,10 +2595,7 @@ class DeepSpeedEngine(Module):
load_zero_checkpoint = self.zero_optimization() or self.bfloat16_enabled()
if load_zero_checkpoint and load_path is not None:
success = self._load_zero_checkpoint(
load_dir,
tag,
load_optimizer_states=load_optimizer_states)
success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
if not success:
self.optimizer._restore_from_bit16_weights()
......@@ -2794,16 +2616,12 @@ class DeepSpeedEngine(Module):
from deepspeed.runtime.state_dict_factory import SDLoaderFactory
ckpt_list = self._get_all_ckpt_names(load_dir, tag)
sd_loader = SDLoaderFactory.get_sd_loader(
ckpt_list,
checkpoint_engine=self.checkpoint_engine)
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine)
is_pipe_parallel = isinstance(self.module, PipelineModule)
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
load_path, checkpoint, _ = sd_loader.load(
self.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel
)
load_path, checkpoint, _ = sd_loader.load(self.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel)
if checkpoint is None:
return None, None
......@@ -2826,7 +2644,7 @@ class DeepSpeedEngine(Module):
num_experts=self.num_experts,
checkpoint_engine=self.checkpoint_engine)
if not self.load_universal_checkpoint():
self.load_module_state_dict(state_dict=checkpoint['module'],
self.load_module_state_dict(checkpoint=checkpoint,
strict=load_module_strict,
custom_load_fn=custom_load_fn)
......@@ -2841,38 +2659,29 @@ class DeepSpeedEngine(Module):
largest_group_name = groups._get_max_expert_size_name()
expp_rank = groups._get_expert_parallel_rank(largest_group_name)
optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank)
optim_checkpoint = self.checkpoint_engine.load(
optim_load_path,
map_location=torch.device('cpu'))
optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu'))
else:
optim_checkpoint = checkpoint
has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled(
)
has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state:
if self.fp16_enabled():
self.optimizer.load_state_dict(
optim_checkpoint['optimizer'],
load_optimizer_states=load_optimizer_states)
self.optimizer.load_state_dict(optim_checkpoint['optimizer'],
load_optimizer_states=load_optimizer_states)
else:
self.optimizer.load_state_dict(optim_checkpoint['optimizer'])
if load_lr_scheduler_states and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
if self.random_ltd_enabled(
) and self.random_ltd_scheduler is not None and 'random_ltd' in checkpoint:
if self.random_ltd_enabled() and self.random_ltd_scheduler is not None and 'random_ltd' in checkpoint:
self.random_ltd_scheduler.load_state_dict(checkpoint['random_ltd'])
if self.training_dataloader is not None and self.curriculum_learning_enabled(
) and 'data_sampler' in checkpoint:
self.training_dataloader.data_sampler.load_state_dict(
checkpoint['data_sampler'])
self.training_dataloader.data_sampler.load_state_dict(checkpoint['data_sampler'])
def get_sparse_tensor_module_names(original_set,
loaded_set,
original_parameters,
loaded_parameters):
def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters, loaded_parameters):
result = set()
for name in original_set:
......@@ -2882,8 +2691,7 @@ class DeepSpeedEngine(Module):
for name in loaded_set:
if name in original_parameters:
result.add(
name) # parameter exists in both configs and it was sparse
result.add(name) # parameter exists in both configs and it was sparse
return result
......@@ -2898,26 +2706,16 @@ class DeepSpeedEngine(Module):
self.sparse_tensor_module_names = sparse_tensor_module_names
else:
self.sparse_tensor_module_names = get_sparse_tensor_module_names(
self.sparse_tensor_module_names,
sparse_tensor_module_names,
dict(self.module.named_parameters()),
checkpoint["module"])
self.sparse_tensor_module_names, sparse_tensor_module_names,
dict(self.module.named_parameters()), checkpoint["module"])
self.global_steps = checkpoint['global_steps']
self.global_samples = checkpoint.get(
'global_samples',
self.global_steps * self.train_batch_size())
self.global_samples = checkpoint.get('global_samples', self.global_steps * self.train_batch_size())
self.skipped_steps = checkpoint['skipped_steps']
self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
deepspeed_states = [
'module',
'sparse_tensor_module_names',
'skipped_steps',
'global_steps',
'dp_world_size',
'mp_world_size',
'data_sampler',
'random_ltd'
'module', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'dp_world_size',
'mp_world_size', 'data_sampler', 'random_ltd'
]
client_state = {}
......@@ -2926,11 +2724,7 @@ class DeepSpeedEngine(Module):
if load_optimizer_states:
deepspeed_states.append('optimizer')
client_state = {
key: value
for key,
value in checkpoint.items() if not key in deepspeed_states
}
client_state = {key: value for key, value in checkpoint.items() if not key in deepspeed_states}
if not load_optimizer_states and not load_module_only:
client_state['optimizer'] = optim_checkpoint['optimizer']
......@@ -2953,28 +2747,18 @@ class DeepSpeedEngine(Module):
if zero_sd_list is None:
return False
self.optimizer.load_state_dict(
state_dict_list=zero_sd_list,
load_optimizer_states=load_optimizer_states,
load_from_fp32_weights=self.zero_load_from_fp32_weights(),
checkpoint_folder=checkpoint_folder)
self.optimizer.load_state_dict(state_dict_list=zero_sd_list,
load_optimizer_states=load_optimizer_states,
load_from_fp32_weights=self.zero_load_from_fp32_weights(),
checkpoint_folder=checkpoint_folder)
if self.load_universal_checkpoint():
logger.info(
f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}'
)
logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}')
else:
logger.info(
f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}"
)
logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
return True
def _get_mp_rank_zero_checkpoint_names(self,
load_dir,
tag,
mp_rank,
dp_world_size,
bf16_mode):
def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
zero_ckpt_names = []
for dp_rank in range(dp_world_size):
ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir,
......@@ -2988,18 +2772,16 @@ class DeepSpeedEngine(Module):
def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(
load_dir=load_dir,
tag=tag,
mp_rank=mp_rank,
dp_world_size=self.loaded_checkpoint_dp_world_size,
bf16_mode=bf16_mode)
zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(load_dir=load_dir,
tag=tag,
mp_rank=mp_rank,
dp_world_size=self.loaded_checkpoint_dp_world_size,
bf16_mode=bf16_mode)
for i, ckpt_name in enumerate(zero_ckpt_names):
if not os.path.exists(ckpt_name):
# transparently handle the old file pattern for optim_states
if "optim_states.pt" in ckpt_name:
ckpt_name_try = ckpt_name.replace("_optim_states.pt",
"optim_states.pt")
ckpt_name_try = ckpt_name.replace("_optim_states.pt", "optim_states.pt")
if os.path.exists(ckpt_name_try):
zero_ckpt_names[i] = ckpt_name_try
continue
......@@ -3013,8 +2795,7 @@ class DeepSpeedEngine(Module):
if ckpt_name is None:
_state = {OPTIMIZER_STATE_DICT: None}
# Fully load state for current rank
elif self.zero_elastic_checkpoint() or dist.get_rank(
group=self.optimizer.dp_process_group) == i:
elif self.zero_elastic_checkpoint() or dist.get_rank(group=self.optimizer.dp_process_group) == i:
_state = self.checkpoint_engine.load(
ckpt_name,
map_location='cpu',
......@@ -3024,25 +2805,18 @@ class DeepSpeedEngine(Module):
zero_sd_list.append(_state)
zero_optimizer_sd = [sd[OPTIMIZER_STATE_DICT] for sd in zero_sd_list]
logger.info(
f"successfully read {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}"
)
logger.info(f"successfully read {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}")
return zero_optimizer_sd
def _get_all_zero_checkpoints(self, load_dir, tag):
for bf16_mode in [self.bfloat16_enabled(), not self.bfloat16_enabled()]:
zero_ckpt_names = self._get_all_zero_checkpoint_names(
load_dir,
tag,
bf16_mode)
zero_ckpt_names = self._get_all_zero_checkpoint_names(load_dir, tag, bf16_mode)
if zero_ckpt_names is not None:
# Warn if loading checkpoint of different bit16 type
if bf16_mode is not self.bfloat16_enabled():
checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16
engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16
logger.warn(
f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine'
)
logger.warn(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names)
return None
......@@ -3056,10 +2830,9 @@ class DeepSpeedEngine(Module):
dist.all_reduce(max_bhash, op=dist.ReduceOp.MAX)
dist.all_reduce(min_bhash, op=dist.ReduceOp.MIN)
valid = all(min_bhash == bhash) and all(max_bhash == bhash)
msg = (
f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across "
"all ranks. Including rank unique information in checkpoint tag could cause issues when "
"restoring with different world sizes.")
msg = (f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across "
"all ranks. Including rank unique information in checkpoint tag could cause issues when "
"restoring with different world sizes.")
if self.checkpoint_tag_validation_fail():
assert valid, msg
elif not valid:
......@@ -3090,7 +2863,7 @@ class DeepSpeedEngine(Module):
# There seems to be an issue creating them in parallel
# Ensure save_dir directory exists
os.makedirs(save_dir, exist_ok=True)
self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
dist.barrier()
if tag is None:
......@@ -3191,15 +2964,9 @@ class DeepSpeedEngine(Module):
# let's save the moe parameters
for global_expert_id, expert_state_dict in experts_state_dict.items():
# save the moe parameters
moe_save_path = self._get_expert_ckpt_name(
save_dir,
moe_layer_id,
global_expert_id,
tag,
self.mpu)
moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
if self.random_ltd_enabled():
expert_state_dict = remove_random_ltd_state_dict(
expert_state_dict)
expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)
self.checkpoint_engine.save(expert_state_dict, moe_save_path)
moe_layer_id += 1
......@@ -3217,9 +2984,7 @@ class DeepSpeedEngine(Module):
# Save optimizer states. They are different across each exp parallel rank.
optimizer_state = {
'optimizer':
self.optimizer.state_dict()
if self.optimizer and not self.zero_optimization() else None
'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None
}
# TODO: why use BufferedWriter and not the path
file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
......@@ -3234,15 +2999,12 @@ class DeepSpeedEngine(Module):
'module':
model_state_dict,
'lr_scheduler':
self.lr_scheduler.state_dict()
if self.lr_scheduler is not None else None,
self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
'data_sampler':
self.training_dataloader.data_sampler.state_dict() if
(self.training_dataloader is not None
and self.curriculum_learning_enabled()) else None,
(self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,
'random_ltd':
self.random_ltd_scheduler.state_dict()
if self.random_ltd_enabled() else None,
self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,
'sparse_tensor_module_names':
self.sparse_tensor_module_names,
'skipped_steps':
......@@ -3264,11 +3026,11 @@ class DeepSpeedEngine(Module):
self._curr_save_path = None
def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
name_function = (self._get_zero_ckpt_name
if zero_checkpoint else self._get_ckpt_name)
name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name)
try:
checkpoint_name = name_function(save_dir, tag)
ensure_directory_exists(checkpoint_name)
path = os.path.dirname(checkpoint_name)
self.checkpoint_engine.makedirs(path, exist_ok=True)
except:
logger.error(f"Failed saving model checkpoint to {save_dir} with tag {tag}")
return False
......@@ -3292,6 +3054,8 @@ class DeepSpeedEngine(Module):
zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
save_frozen_param = self.zero_optimization_partition_gradients()
# A hack to save the checkpointing directory. Pipeline parallelism overrides
# module_state_dict() and uses this path to save the model. module_state_dict()
# then instead just returns None. The module_state_dict() implementation in
......@@ -3302,17 +3066,17 @@ class DeepSpeedEngine(Module):
state = dict(module=module,
buffer_names=self._get_buffer_names(),
optimizer=self.optimizer.state_dict()
if self.optimizer and not zero_optimizer_state else None,
param_shapes=self._get_zero_param_shapes()
if self.optimizer and zero_optimizer_state else None,
lr_scheduler=self.lr_scheduler.state_dict()
if self.lr_scheduler is not None else None,
optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None,
param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None,
frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func)
if save_frozen_param else None,
shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None,
frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func)
if save_frozen_param else None,
lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
data_sampler=self.training_dataloader.data_sampler.state_dict() if
(self.training_dataloader is not None
and self.curriculum_learning_enabled()) else None,
random_ltd=self.random_ltd_scheduler.state_dict()
if self.random_ltd_enabled() else None,
(self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,
random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,
sparse_tensor_module_names=self.sparse_tensor_module_names,
skipped_steps=self.skipped_steps,
global_steps=self.global_steps,
......@@ -3348,6 +3112,25 @@ class DeepSpeedEngine(Module):
return buffer_names
def _get_param_shape_func(self, param):
return param.ds_shape if hasattr(param, 'ds_id') else param.shape
def _get_param_fragment_func(self, param):
return param.ds_tensor.detach().cpu() if hasattr(param, 'ds_id') else param.detach().cpu()
def _get_zero_frozen_param_attributes(self, attr_func):
frozen_param_fragments = OrderedDict()
for param in self.module.parameters():
if param.requires_grad:
continue
if param not in self.param_names:
raise ValueError(f"failed to find frozen {param} in named params")
name = self.param_names[param]
frozen_param_fragments[name] = attr_func(param)
return frozen_param_fragments
def _get_zero_param_shapes(self):
"""Returns a dict of name to shape mapping, only for the flattened fp32 weights saved by the
optimizer. The names are exactly as in state_dict. The order is absolutely important, since
......@@ -3390,6 +3173,40 @@ class DeepSpeedEngine(Module):
return param_group_shapes
def _get_shared_params(self):
"""
Returns a dict of shared params, which can later be used to reconstruct the original state dict,
e.g. in `zero_to_fp32`. Each dict entry is a pair of param names, where the key is the name
of the variable that isn't stored and the value is the actual param holding data.
"""
shared_ds_ids = {}
shared_params_by_full_name = {}
def get_layer_state_dict(module, prefix=""):
# handle params
for name, param in module.named_parameters(recurse=False):
if param is None or not hasattr(param, "ds_id"):
continue
key = prefix + name
# can't rely on param.data_ptr() as it will be reused as weights get
# gathered and reduced, but param.ds_id is unique across all zero weights
# (and shared params will have the same param.ds_id)
if param.ds_id in shared_ds_ids:
# shared weights
#print(f"`{key}` is shared with `{shared_ds_ids[param.ds_id]}`")
shared_params_by_full_name[key] = shared_ds_ids[param.ds_id]
else:
shared_ds_ids[param.ds_id] = key
for name, child in module.named_children():
if child is not None:
get_layer_state_dict(child, prefix + name + ".")
if dist.get_rank() == 0:
get_layer_state_dict(self.module, prefix="")
return shared_params_by_full_name
def _copy_recovery_script(self, save_path):
base_dir = os.path.dirname(os.path.dirname(__file__))
script = "zero_to_fp32.py"
......@@ -3402,9 +3219,7 @@ class DeepSpeedEngine(Module):
def _save_zero_checkpoint(self, save_path, tag):
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(),
ds_config=self.config,
ds_version=version)
zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), ds_config=self.config, ds_version=version)
self.checkpoint_engine.save(zero_sd, zero_checkpoint_name)
if self.global_rank == 0:
......@@ -3434,9 +3249,7 @@ class DeepSpeedEngine(Module):
# gather one layer at a time to be memory-efficient
# must use modifier_rank=0 to release GPU memory after each layer gathered
#see_memory_usage("before GatheredParameters", force=True)
with deepspeed.zero.GatheredParameters(list(
module.parameters(recurse=False)),
modifier_rank=0):
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if dist.get_rank() == 0:
# handle params
for name, param in module.named_parameters(recurse=False):
......@@ -3457,8 +3270,7 @@ class DeepSpeedEngine(Module):
# now buffers - not sure if we need to take care of potentially shared weights here
for name, buf in module.named_buffers(recurse=False):
if (buf is not None
and name not in module._non_persistent_buffers_set):
if (buf is not None and name not in module._non_persistent_buffers_set):
state_dict[prefix + name] = buf.detach().cpu()
#see_memory_usage("after GatheredParameters", force=True)
......@@ -3511,15 +3323,29 @@ class DeepSpeedEngine(Module):
else:
# the model will be bogus if not consolidated so don't confuse the user by saving it
logger.info(
f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False"
)
f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False")
return False
else:
state_dict = self.module.state_dict()
tag = f"global_step{self.global_steps}"
tag = str(tag)
self.checkpoint_engine.create(tag)
if dist.get_rank() == 0:
os.makedirs(save_dir, exist_ok=True)
logger.info(f"Saving model weights to {path}")
self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
logger.info(f"Saving model weights to {path}, tag: {tag}")
self.checkpoint_engine.save(state_dict, path)
self.checkpoint_engine.commit(tag)
return True
def empty_partition_cache(self):
"""
Release GPU memory consumed by offloaded model parameters.
"""
if hasattr(self.optimizer, 'empty_partition_cache'):
self.optimizer.empty_partition_cache()
gc.collect()
get_accelerator().empty_cache()
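# Illustrative usage (engine name assumed): after an evaluation pass with ZeRO offload
# one might call
#   engine.empty_partition_cache()
# to release gathered parameter memory back to the accelerator allocator.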
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex
"""
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
......@@ -23,6 +25,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
For a usage example please see TODO: DeepSpeed V2 Tutorial
"""
def __init__(self,
init_optimizer,
deepspeed=None,
......@@ -58,20 +61,15 @@ class FP16_Optimizer(DeepSpeedOptimizer):
# push this group to list before modify
self.fp16_groups.append(param_group['params'])
# init fp16 weight buffer, flattened
self.fp16_groups_flat.append(
_flatten_dense_tensors([p.clone().detach()
for p in self.fp16_groups[i]]))
self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
# init master weight, flattened
self.fp32_groups_flat.append(
self.fp16_groups_flat[i].clone().float().detach())
self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
# modify optimizer to have flat master weight
self.fp32_groups_flat[
i].requires_grad = True # keep this in case internal optimizer uses it
self.fp32_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.fp32_groups_flat[i]]
# we may have a way of fusing dynamic scale. Do not support for now
......@@ -113,16 +111,13 @@ class FP16_Optimizer(DeepSpeedOptimizer):
self.mpu = mpu
self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups,
mpu=self.mpu,
deepspeed=deepspeed)
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu, deepspeed=deepspeed)
self.initialize_optimizer_states()
def initialize_optimizer_states(self):
for i, group in enumerate(self.fp16_groups):
self.fp32_groups_flat[i].grad = torch.zeros(
self.fp32_groups_flat[i].size(),
device=self.fp32_groups_flat[i].device)
self.fp32_groups_flat[i].grad = torch.zeros(self.fp32_groups_flat[i].size(),
device=self.fp32_groups_flat[i].device)
self.optimizer.step()
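# Taking one step with all-zero gradients prompts the wrapped optimizer to allocate
# its internal state tensors (e.g. Adam's moment buffers) up front.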
......@@ -156,10 +151,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
for i, group in enumerate(self.fp16_groups):
grads_groups_flat.append(
_flatten_dense_tensors([
torch.zeros(p.size(),
dtype=p.dtype,
device=p.device) if p.grad is None else p.grad
for p in group
torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad for p in group
]))
norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))
......@@ -169,17 +161,13 @@ class FP16_Optimizer(DeepSpeedOptimizer):
if self.overflow:
if self.verbose:
logger.info(
"[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
return self.overflow
scaled_grad_norm = get_global_norm(norm_list=norm_groups)
combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
scaled_grad_norm,
apply_scale=False)
combined_scale = self.unscale_and_clip_grads(grads_groups_flat, scaled_grad_norm, apply_scale=False)
# Stash unscaled gradient norm
self._global_grad_norm = scaled_grad_norm / self.cur_scale
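# scaled_grad_norm was measured on loss-scaled gradients, so dividing by
# self.cur_scale recovers the norm of the true (unscaled) gradients.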
......@@ -191,8 +179,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
grad_norms=norm_groups)
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
return self.overflow
......@@ -222,9 +209,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
def override_loss_scale(self, loss_scale):
if loss_scale != self.external_loss_scale:
logger.info(
f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}'
)
logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
self.custom_loss_scaler = True
self.external_loss_scale = loss_scale
......@@ -273,10 +258,8 @@ class FP16_Optimizer(DeepSpeedOptimizer):
grads_groups_flat.append(
_flatten_dense_tensors([
torch.zeros(p.size(),
dtype=data_type,
device=p.device)
if p.grad is None else p.grad.to(data_type) for p in group
torch.zeros(p.size(), dtype=data_type, device=p.device) if p.grad is None else p.grad.to(data_type)
for p in group
]))
for p in group:
......@@ -313,8 +296,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
self.start_timers([UPDATE_FP16])
for i in range(len(self.fp16_groups)):
updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
self.fp16_groups[i])
updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data.copy_(q.data)
......@@ -334,9 +316,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
else:
pg = groups._get_data_parallel_group()
scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
scaled_norm_tensor = torch.tensor(scaled_norm,
device=self.fp32_groups_flat[0].device,
dtype=torch.float)
scaled_norm_tensor = torch.tensor(scaled_norm, device=self.fp32_groups_flat[0].device, dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=pg)
all_groups_norm = scaled_norm_tensor.item()
#print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
......@@ -376,25 +356,19 @@ class FP16_Optimizer(DeepSpeedOptimizer):
if self.dynamic_loss_scale:
prev_scale = self.cur_scale
if skip:
self.cur_scale = max(self.cur_scale / self.scale_factor,
self.min_loss_scale)
self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
self.last_overflow_iter = self.cur_iter
if self.verbose:
logger.info(f"\nGrad overflow on iteration {self.cur_iter}")
logger.info(
f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}"
)
logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}")
else:
# Ensure self.scale_window updates since last overflow
stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
self.cur_scale *= self.scale_factor
if self.verbose:
logger.info(
f"No Grad overflow for {self.scale_window} iterations")
logger.info(
f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}"
)
logger.info(f"No Grad overflow for {self.scale_window} iterations")
logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}")
else:
if skip:
logger.info("Grad overflow on iteration: %s", self.cur_iter)
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Taken and modified for DeepSpeed from:
https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
"""
import torch
from deepspeed import comm as dist
from deepspeed.utils import logger
INITIAL_LOSS_SCALE = 'init_scale'
SCALE_WINDOW = 'scale_window'
......@@ -35,6 +42,7 @@ class LossScalerBase:
"""LossScalarBase
Base class for a loss scaler
"""
def __init__(self, cur_scale):
self.cur_scale = cur_scale
self.dynamic = False
......@@ -52,6 +60,7 @@ class LossScalerBase:
def backward(self, loss, retain_graph=False):
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
# print(f'LossScalerBackward: {scaled_loss=}')
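# Note: the gradients produced by this backward call are scaled by self.loss_scale;
# the optimizer is expected to divide them back out (as FP16_Optimizer does in
# unscale_and_clip_grads) before applying the update.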
class LossScaler(LossScalerBase):
......@@ -65,6 +74,7 @@ class LossScaler(LossScalerBase):
Args:
scale (float, optional, default=1.0): The loss scale.
"""
def __init__(self, scale=1):
super(LossScaler, self).__init__(scale)
......@@ -102,6 +112,7 @@ class DynamicLossScaler(LossScalerBase):
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
"""
def __init__(self,
init_scale=2**32,
scale_factor=2.,
......@@ -109,7 +120,8 @@ class DynamicLossScaler(LossScalerBase):
min_scale=1,
delayed_shift=1,
consecutive_hysteresis=False,
raise_error_at_min_scale=True):
raise_error_at_min_scale=True,
dtype=torch.half):
super(DynamicLossScaler, self).__init__(init_scale)
self.cur_iter = 0
self.last_overflow_iter = -1
......@@ -121,6 +133,7 @@ class DynamicLossScaler(LossScalerBase):
self.consecutive_hysteresis = consecutive_hysteresis
self.raise_error_at_min_scale = raise_error_at_min_scale
self.dynamic = True
self.dtype = dtype
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
......@@ -158,10 +171,21 @@ class DynamicLossScaler(LossScalerBase):
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale:
raise Exception(
"Current loss scale already at minimum - cannot decrease scale anymore. Exiting run."
)
self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
"Current loss scale already at minimum - cannot decrease scale anymore. Exiting run.")
else:
next_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
if dist.get_rank() == 0:
overflow_msg = f"[deepspeed] OVERFLOW! Rank {dist.get_rank()} Skipping step."
if self.dtype == torch.half:
overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, reducing to {int(next_scale)}"
logger.info(overflow_msg)
self.cur_scale = next_scale
else:
if dist.get_rank() == 0:
overflow_msg = f"[deepspeed] OVERFLOW! Rank {dist.get_rank()} Skipping step."
if self.dtype == torch.half:
overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, but hysteresis is {self.cur_hysteresis}. Reducing hysteresis to {self.cur_hysteresis-1}"
logger.info(overflow_msg)
self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter
else:
......@@ -179,8 +203,8 @@ class DynamicLossScaler(LossScalerBase):
def CreateLossScaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args):
if dtype == torch.half and dynamic_scaling:
if dynamic_loss_args is None:
return DynamicLossScaler()
return DynamicLossScaler(**dynamic_loss_args)
return DynamicLossScaler(dtype=dtype)
return DynamicLossScaler(dtype=dtype, **dynamic_loss_args)
loss_scale_value = static_loss_scale if dtype == torch.half else 1.0
return LossScaler(scale=loss_scale_value)
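# Illustrative usage (values assumed):
#   CreateLossScaler(torch.half, 128.0, dynamic_scaling=True, dynamic_loss_args=None)
#       -> DynamicLossScaler(dtype=torch.half) with default settings
#   CreateLossScaler(torch.half, 128.0, dynamic_scaling=False, dynamic_loss_args=None)
#       -> LossScaler(scale=128.0)
#   CreateLossScaler(torch.bfloat16, 128.0, dynamic_scaling=True, dynamic_loss_args=None)
#       -> LossScaler(scale=1.0), since dynamic scaling only applies to fp16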
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .adam import OnebitAdam
from .lamb import OnebitLamb
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import types
import torch
import numpy as np
......@@ -39,14 +41,14 @@ class OnebitAdam(torch.optim.Optimizer):
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self,
params,
deepspeed=None,
lr=1e-3,
freeze_step=100000,
bias_correction=True,
betas=(0.9,
0.999),
betas=(0.9, 0.999),
eps=1e-8,
eps_inside_sqrt=False,
weight_decay=0.,
......@@ -89,11 +91,12 @@ class OnebitAdam(torch.optim.Optimizer):
if self.comm_backend_name == 'nccl':
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert (
(TORCH_MAJOR == 1 and TORCH_MINOR >= 8) or TORCH_MAJOR >= 2
), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend
self.using_pipeline = hasattr(self.deepspeed,
'pipeline_enable_backward_allreduce')
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)
elif self.comm_backend_name == 'mpi':
......@@ -164,22 +167,17 @@ class OnebitAdam(torch.optim.Optimizer):
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if not self.initialize or (self.adam_freeze_key
and 'worker_error' not in state.keys()):
if not self.initialize or (self.adam_freeze_key and 'worker_error' not in state.keys()):
state['tensor_size'] = torch.numel(p.data)
state['corrected_tensor_size'] = state['tensor_size']
if state['tensor_size'] % (self.size * self.divider) != 0:
state['corrected_tensor_size'] += ((self.size * self.divider) -
(state['tensor_size'] %
(self.size * self.divider)))
state['server_chunk_size'] = state[
'corrected_tensor_size'] // self.size
state['corrected_tensor_size'] += ((self.size * self.divider) - (state['tensor_size'] %
(self.size * self.divider)))
state['server_chunk_size'] = state['corrected_tensor_size'] // self.size
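# Illustrative sizing (numbers assumed): with self.size=4 workers and self.divider=2,
# a 1001-element tensor is padded to corrected_tensor_size=1008 (the next multiple of
# size * divider = 8), giving server_chunk_size = 1008 // 4 = 252 elements per worker.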
get_accelerator().empty_cache()
state['worker_error'] = torch.zeros(state['corrected_tensor_size'],
device=p.device)
state['server_error'] = torch.zeros(state['server_chunk_size'],
device=p.device)
state['worker_error'] = torch.zeros(state['corrected_tensor_size'], device=p.device)
state['server_error'] = torch.zeros(state['server_chunk_size'], device=p.device)
get_accelerator().empty_cache()
self.adam_freeze_key = True
if not self.initialize and dist.get_rank() == 0:
......@@ -211,11 +209,9 @@ class OnebitAdam(torch.optim.Optimizer):
if self.size > 1:
exp_avg.set_(
self.comm_backend_handle.compressed_allreduce(
exp_avg,
state['worker_error'],
state['server_error'],
self.deepspeed.local_rank))
self.comm_backend_handle.compressed_allreduce(exp_avg, state['worker_error'],
state['server_error'],
self.deepspeed.local_rank))
# Because 1-bit compression cannot represent exact zero, it is required to
# provide a momentum mask for those params that have constant exact zeros in their
# momentums, otherwise the compression error would keep accumulating.
......@@ -225,8 +221,7 @@ class OnebitAdam(torch.optim.Optimizer):
# (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
if 'exp_avg_mask' in group:
if exp_avg.device != group['exp_avg_mask'].device:
group['exp_avg_mask'] = group['exp_avg_mask'].to(
device=exp_avg.device)
group['exp_avg_mask'] = group['exp_avg_mask'].to(device=exp_avg.device)
exp_avg.mul_(group['exp_avg_mask'])
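# Illustrative only (not part of this diff): a caller could attach such a mask to the
# param group before training, e.g. to pin the momentum of a padding row to exact zero:
#   mask = torch.ones_like(embedding.weight)
#   mask[padding_idx] = 0.0
#   optimizer.param_groups[idx]['exp_avg_mask'] = mask
# The multiply above then re-zeroes those entries after every compressed allreduce.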
if self.initialize:
......@@ -272,8 +267,7 @@ class OnebitAdam(torch.optim.Optimizer):
for i, group in enumerate(self.param_groups):
if 'exp_avg_mask' in group:
state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask']
elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict[
'param_groups'][i]:
elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict['param_groups'][i]:
state_dict['param_groups'][i].pop('exp_avg_mask')
super().load_state_dict(state_dict)
if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
......@@ -287,9 +281,7 @@ class OnebitAdam(torch.optim.Optimizer):
self.deepspeed.enable_backward_allreduce = True
else:
if dist.get_rank() == 0:
print(
"Checkpoint loaded and OnebitAdam compression stage starts/continues."
)
print("Checkpoint loaded and OnebitAdam compression stage starts/continues.")
if self.adam_freeze_key is False:
self.adam_freeze_key = True
if self.using_pipeline:
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import types
import torch
import numpy as np
......@@ -54,14 +56,14 @@ class OnebitLamb(torch.optim.Optimizer):
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self,
params,
deepspeed=None,
lr=1e-3,
freeze_step=100000,
bias_correction=True,
betas=(0.9,
0.999),
betas=(0.9, 0.999),
eps=1e-8,
eps_inside_sqrt=False,
weight_decay=0.,
......@@ -111,11 +113,12 @@ class OnebitLamb(torch.optim.Optimizer):
if self.comm_backend_name == 'nccl':
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert (
(TORCH_MAJOR == 1 and TORCH_MINOR >= 8) or TORCH_MAJOR >= 2
), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend
self.using_pipeline = hasattr(self.deepspeed,
'pipeline_enable_backward_allreduce')
self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)
elif self.comm_backend_name == 'mpi':
......@@ -165,24 +168,20 @@ class OnebitLamb(torch.optim.Optimizer):
if self.lamb_freeze_key:
exp_avg_last_step = []
for group in self.param_groups:
exp_avg_last_step.append(
[self.state[p]['exp_avg'].detach().clone() for p in group['params']])
exp_avg_last_step.append([self.state[p]['exp_avg'].detach().clone() for p in group['params']])
if 'scaling_coeff' not in self.state[self.param_groups[0]['params'][0]]:
# Compute the scaling_coeff for each momentum at the end of warmup stage.
# This is used to reduce compression error during compression stage.
momentum_scales = []
for group in self.param_groups:
momentum_scales.append([
(torch.norm(self.state[p]['exp_avg']) /
np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
(torch.norm(self.state[p]['exp_avg']) / np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
for p in group['params']
])
united_scale = sum([sum(x) for x in momentum_scales]) / sum(
[len(x) for x in momentum_scales])
united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])
for i, group in enumerate(self.param_groups):
for j, p in enumerate(group['params']):
self.state[p][
'scaling_coeff'] = united_scale / momentum_scales[i][j]
self.state[p]['scaling_coeff'] = united_scale / momentum_scales[i][j]
for group, grads_this_group in zip(self.param_groups, grads_group):
if grads_this_group is None:
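At the end of warmup each momentum tensor gets a scaling_coeff so that all tensors share a comparable RMS magnitude before being fused and compressed; the coefficient is the ratio of the average scale to the tensor's own scale. A self-contained sketch of that computation, using a single flat list instead of per-group lists for brevity:

import torch

def compute_scaling_coeffs(momenta):
    # Per-tensor RMS of the momentum, as in the code above:
    # ||exp_avg|| / sqrt(numel(exp_avg))
    scales = [(m.norm() / m.numel() ** 0.5).item() for m in momenta]
    united_scale = sum(scales) / len(scales)
    # Multiplying exp_avg by united_scale / scale equalizes magnitudes across
    # tensors, which reduces 1-bit compression error on the fused buffer.
    return [united_scale / s for s in scales]

coeffs = compute_scaling_coeffs([torch.randn(10, 10), 0.01 * torch.randn(5)])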
......@@ -201,8 +200,7 @@ class OnebitLamb(torch.optim.Optimizer):
state = self.state[p]
# State initialization
if len(state) == 0 or (len(state) == 1
and 'scaling_coeff' in state.keys()):
if len(state) == 0 or (len(state) == 1 and 'scaling_coeff' in state.keys()):
state['step'] = 0
state['lamb_coeff_freeze'] = 0.0
state['last_factor'] = 1.0
......@@ -215,7 +213,8 @@ class OnebitLamb(torch.optim.Optimizer):
if not self.initialize:
self.lamb_freeze_key = True
exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_sq_fresh']
exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state[
'exp_avg_sq_fresh']
beta1, beta2 = group['betas']
max_coeff = group['max_coeff']
min_coeff = group['min_coeff']
......@@ -243,8 +242,8 @@ class OnebitLamb(torch.optim.Optimizer):
if lamb_coeff < min_coeff:
lamb_coeff = min_coeff
if lamb_coeff != 1.0:
state['lamb_coeff_freeze'] = self.coeff_beta * state[
'lamb_coeff_freeze'] + (1 - self.coeff_beta) * lamb_coeff
state['lamb_coeff_freeze'] = self.coeff_beta * state['lamb_coeff_freeze'] + (
1 - self.coeff_beta) * lamb_coeff
self.lamb_coeffs.append(lamb_coeff)
with torch.no_grad():
p.add_(-group['lr'] * lamb_coeff * update)
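During warmup the clipped LAMB trust ratio is folded into `lamb_coeff_freeze` with an exponential moving average controlled by `coeff_beta`, so a stable per-parameter coefficient is available once compression starts and exact coefficients can no longer be computed. A sketch of that update; the default-looking values are illustrative assumptions:

def update_frozen_coeff(lamb_coeff_freeze, lamb_coeff, coeff_beta=0.9,
                        min_coeff=0.01, max_coeff=10.0):
    # Clip the current trust ratio, then fold it into the running average
    # that will be frozen and reused once the compression stage begins.
    lamb_coeff = max(min_coeff, min(max_coeff, lamb_coeff))
    return coeff_beta * lamb_coeff_freeze + (1 - coeff_beta) * lamb_coeff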
......@@ -266,20 +265,15 @@ class OnebitLamb(torch.optim.Optimizer):
tensor_size += torch.numel(p.data)
corrected_tensor_size = tensor_size
if tensor_size % (self.size * self.divider) != 0:
difference = ((self.size * self.divider) - (tensor_size %
(self.size * self.divider)))
difference = ((self.size * self.divider) - (tensor_size % (self.size * self.divider)))
corrected_tensor_size += difference
self.dummy_exp_avg[0] = torch.zeros(
difference,
device=momentum_groups[0].data.device)
self.dummy_exp_avg[0] = torch.zeros(difference, device=momentum_groups[0].data.device)
momentum_groups.append(self.dummy_exp_avg[0])
self.corrected_tensor_sizes.append(corrected_tensor_size)
self.server_chunk_sizes.append(corrected_tensor_size // self.size)
self.exp_avg_flat.append(
_flatten_dense_tensors([p.detach().clone() for p in momentum_groups]))
updated_params = _unflatten_dense_tensors(self.exp_avg_flat[0],
momentum_groups)
self.exp_avg_flat.append(_flatten_dense_tensors([p.detach().clone() for p in momentum_groups]))
updated_params = _unflatten_dense_tensors(self.exp_avg_flat[0], momentum_groups)
for p, q in zip(momentum_groups, updated_params):
p.data = q.data
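Before the compression stage all momenta are fused into one flat buffer: the buffer is padded with a dummy tensor so its length divides evenly into server chunks, and each exp_avg is then re-pointed at its view inside the fused buffer. A standalone sketch of the same idea, using a hypothetical fuse_momenta helper:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def fuse_momenta(momenta, world_size, divider=4):
    # Pad so the fused length splits evenly into world_size * divider chunks.
    total = sum(m.numel() for m in momenta)
    chunk = world_size * divider
    if total % chunk != 0:
        pad = chunk - total % chunk
        momenta = momenta + [torch.zeros(pad, device=momenta[0].device)]
    flat = _flatten_dense_tensors([m.detach().clone() for m in momenta])
    # Re-link each momentum tensor to its view inside the fused buffer.
    for m, view in zip(momenta, _unflatten_dense_tensors(flat, momenta)):
        m.data = view.data
    return flat

flat = fuse_momenta([torch.randn(7), torch.randn(5)], world_size=2)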
......@@ -287,11 +281,8 @@ class OnebitLamb(torch.optim.Optimizer):
get_accelerator().empty_cache()
for i in range(len(self.exp_avg_flat)):
self.worker_errors.append(
torch.zeros(self.corrected_tensor_sizes[i],
device=self.exp_avg_flat[i].device))
self.server_errors.append(
torch.zeros(self.server_chunk_sizes[i],
device=self.exp_avg_flat[i].device))
torch.zeros(self.corrected_tensor_sizes[i], device=self.exp_avg_flat[i].device))
self.server_errors.append(torch.zeros(self.server_chunk_sizes[i], device=self.exp_avg_flat[i].device))
get_accelerator().empty_cache()
if self.lamb_freeze_key:
......@@ -300,31 +291,23 @@ class OnebitLamb(torch.optim.Optimizer):
if not self.initialize:
get_accelerator().empty_cache()
self.worker_errors.append(
torch.zeros(self.corrected_tensor_sizes[i],
device=self.exp_avg_flat[i].device))
torch.zeros(self.corrected_tensor_sizes[i], device=self.exp_avg_flat[i].device))
self.server_errors.append(
torch.zeros(self.server_chunk_sizes[i],
device=self.exp_avg_flat[i].device))
torch.zeros(self.server_chunk_sizes[i], device=self.exp_avg_flat[i].device))
get_accelerator().empty_cache()
if dist.get_rank() == 0:
print("Cupy Buffers Initialized Successfully.")
self.comm_backend_handle.compressed_allreduce(
self.exp_avg_flat[i],
self.worker_errors[0],
self.server_errors[0],
self.deepspeed.local_rank)
self.comm_backend_handle.compressed_allreduce(self.exp_avg_flat[i], self.worker_errors[0],
self.server_errors[0], self.deepspeed.local_rank)
if dist.get_rank() == 0:
print('Pop out errors', flush=True)
del self.worker_errors[:]
del self.server_errors[:]
else:
self.comm_backend_handle.compressed_allreduce(
self.exp_avg_flat[i],
self.worker_errors[i],
self.server_errors[i],
self.deepspeed.local_rank)
self.comm_backend_handle.compressed_allreduce(self.exp_avg_flat[i], self.worker_errors[i],
self.server_errors[i], self.deepspeed.local_rank)
if self.lamb_freeze_key and self.initialize:
for i, group in enumerate(self.param_groups):
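compressed_allreduce performs the 1-bit exchange with error feedback: each side adds its residual from the previous step, transmits only signs plus a scale, and keeps the quantization error locally for the next round. The single-process sketch below illustrates the compression step only, not the actual NcclBackend collective, and the mean-absolute-value scale is one common choice rather than necessarily the one used here:

import torch

def onebit_compress_with_error_feedback(buffer, error):
    # Add the residual from the previous step, then quantize to sign * scale.
    compensated = buffer + error
    scale = compensated.abs().mean()          # per-buffer scale (illustrative)
    compressed = compensated.sign() * scale
    error.copy_(compensated - compressed)     # residual fed back next step
    return compressed

buf = torch.randn(16)
err = torch.zeros_like(buf)
out = onebit_compress_with_error_feedback(buf, err)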
......@@ -332,7 +315,8 @@ class OnebitLamb(torch.optim.Optimizer):
for j, p in enumerate(group['params']):
state = self.state[p]
exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_sq_fresh']
exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state[
'exp_avg_sq_fresh']
beta1, beta2 = group['betas']
exp_avg.div_(self.state[p]['scaling_coeff'])
# Because 1-bit compression cannot represent exact zero, it is required to
......@@ -345,15 +329,11 @@ class OnebitLamb(torch.optim.Optimizer):
# to add this exp_avg_mask for BERT pre-training.)
if 'exp_avg_mask' in group:
if exp_avg.device != group['exp_avg_mask'].device:
group['exp_avg_mask'] = group['exp_avg_mask'].to(
device=exp_avg.device)
group['exp_avg_mask'] = group['exp_avg_mask'].to(device=exp_avg.device)
exp_avg.mul_(group['exp_avg_mask'])
grad_reconstruct = ((exp_avg - exp_avg_last_step[i][j] * beta1) /
(1 - beta1))
exp_avg_sq_fresh.mul_(beta2).addcmul_(1 - beta2,
grad_reconstruct,
grad_reconstruct)
grad_reconstruct = ((exp_avg - exp_avg_last_step[i][j] * beta1) / (1 - beta1))
exp_avg_sq_fresh.mul_(beta2).addcmul_(1 - beta2, grad_reconstruct, grad_reconstruct)
denom = exp_avg_sq.sqrt() + group['eps']
update_prelim = exp_avg / denom
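In the compression stage the raw gradient is no longer available after the momentum allreduce, so it is reconstructed by inverting the momentum update, g_t ~= (exp_avg - beta1 * exp_avg_prev) / (1 - beta1), and a "fresh" second moment is maintained from that estimate. A sketch using the modern addcmul_ keyword signature instead of the deprecated positional form seen above:

import torch

def reconstruct_grad_and_fresh_variance(exp_avg, exp_avg_prev, exp_avg_sq_fresh,
                                         beta1=0.9, beta2=0.999):
    # Invert m_t = beta1 * m_{t-1} + (1 - beta1) * g_t to recover an estimate
    # of g_t, then refresh the second-moment estimate with it in place.
    grad_reconstruct = (exp_avg - exp_avg_prev * beta1) / (1 - beta1)
    exp_avg_sq_fresh.mul_(beta2).addcmul_(grad_reconstruct, grad_reconstruct,
                                          value=1 - beta2)
    return grad_reconstruct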
......@@ -367,9 +347,7 @@ class OnebitLamb(torch.optim.Optimizer):
denom_real = exp_avg_sq_fresh.sqrt() + group['eps']
factor = (denom / denom_real).max().item()
if group['weight_decay'] > 0.0:
update_ratio = min(1.0,
(update_prelim.pow(2).sum().sqrt() /
update_norm).item())
update_ratio = min(1.0, (update_prelim.pow(2).sum().sqrt() / update_norm).item())
factor = factor * update_ratio + (1.0 - update_ratio)
if factor > self.factor_max:
factor = self.factor_max
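The factor compares the stale denominator kept from warmup with the freshly reconstructed one, is blended toward 1.0 by the gradient's share of the update when weight decay is active, and is finally clipped to [factor_min, factor_max]. A compact sketch; the 4.0 / 0.5 defaults are illustrative assumptions:

import torch

def clamp_factor(denom, denom_real, update_prelim_norm, update_norm,
                 weight_decay, factor_max=4.0, factor_min=0.5):
    # Ratio between the stale denominator (kept from warmup) and the fresh
    # one; with weight decay, blend toward 1.0 by the fraction of the update
    # that comes from the gradient term, then clip to the allowed range.
    factor = (denom / denom_real).max().item()
    if weight_decay > 0.0:
        update_ratio = min(1.0, update_prelim_norm / update_norm)
        factor = factor * update_ratio + (1.0 - update_ratio)
    return max(factor_min, min(factor_max, factor))

f = clamp_factor(torch.full((4,), 2.0), torch.full((4,), 1.0),
                 update_prelim_norm=3.0, update_norm=5.0, weight_decay=0.01)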
......@@ -416,8 +394,7 @@ class OnebitLamb(torch.optim.Optimizer):
for i, group in enumerate(self.param_groups):
if 'exp_avg_mask' in group:
state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask']
elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict[
'param_groups'][i]:
elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict['param_groups'][i]:
state_dict['param_groups'][i].pop('exp_avg_mask')
super().load_state_dict(state_dict)
# need to reset the fused momentum since loading states will break the linking
......@@ -442,9 +419,7 @@ class OnebitLamb(torch.optim.Optimizer):
self.state[p].pop('scaling_coeff')
else:
if dist.get_rank() == 0:
print(
"Checkpoint loaded and OnebitLamb compression stage starts/continues."
)
print("Checkpoint loaded and OnebitLamb compression stage starts/continues.")
if self.lamb_freeze_key is False:
self.lamb_freeze_key = True
if self.using_pipeline:
......