Commit c25a91b6 authored by aiss's avatar aiss
Browse files

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
''' # Copyright (c) Microsoft Corporation.
Copyright 2021 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import math import math
from deepspeed.utils import logger from deepspeed.utils import logger
from .constants import * from .constants import *
class CurriculumScheduler(object): class CurriculumScheduler(object):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.state = {} self.state = {}
...@@ -16,17 +19,12 @@ class CurriculumScheduler(object): ...@@ -16,17 +19,12 @@ class CurriculumScheduler(object):
f"Curriculum learning requires the config '{CURRICULUM_LEARNING_MAX_DIFFICULTY}'" f"Curriculum learning requires the config '{CURRICULUM_LEARNING_MAX_DIFFICULTY}'"
assert CURRICULUM_LEARNING_SCHEDULE_TYPE in config, \ assert CURRICULUM_LEARNING_SCHEDULE_TYPE in config, \
f"Curriculum learning requires the config '{CURRICULUM_LEARNING_SCHEDULE_TYPE}'" f"Curriculum learning requires the config '{CURRICULUM_LEARNING_SCHEDULE_TYPE}'"
self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY] = config[ self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY] = config[CURRICULUM_LEARNING_MIN_DIFFICULTY]
CURRICULUM_LEARNING_MIN_DIFFICULTY] self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY] = config[CURRICULUM_LEARNING_MAX_DIFFICULTY]
self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY] = config[ self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] = config[CURRICULUM_LEARNING_MIN_DIFFICULTY]
CURRICULUM_LEARNING_MAX_DIFFICULTY] self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] = config[CURRICULUM_LEARNING_SCHEDULE_TYPE]
self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] = config[
CURRICULUM_LEARNING_MIN_DIFFICULTY]
self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] = config[
CURRICULUM_LEARNING_SCHEDULE_TYPE]
self.first_step = True self.first_step = True
if config[ if config[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_DISCRETE:
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_DISCRETE:
""" """
The schedule_config is a list of difficulty and a list of max The schedule_config is a list of difficulty and a list of max
step belonging to each difficulty. Example json config: step belonging to each difficulty. Example json config:
...@@ -43,18 +41,12 @@ class CurriculumScheduler(object): ...@@ -43,18 +41,12 @@ class CurriculumScheduler(object):
f"Curriculum learning with fixed_discrete schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY}'" f"Curriculum learning with fixed_discrete schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY}'"
assert CURRICULUM_LEARNING_SCHEDULE_MAX_STEP in config[CURRICULUM_LEARNING_SCHEDULE_CONFIG], \ assert CURRICULUM_LEARNING_SCHEDULE_MAX_STEP in config[CURRICULUM_LEARNING_SCHEDULE_CONFIG], \
f"Curriculum learning with fixed_discrete schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_MAX_STEP}'" f"Curriculum learning with fixed_discrete schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_MAX_STEP}'"
assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG] assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_MAX_STEP]) > 0
[CURRICULUM_LEARNING_SCHEDULE_MAX_STEP]) > 0 assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY]) > 0
assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG] assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY]) == len(
[CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY]) > 0 config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_MAX_STEP]) + 1
assert len(config[CURRICULUM_LEARNING_SCHEDULE_CONFIG] self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
[CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY]) == len( elif config[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_ROOT:
config[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
[CURRICULUM_LEARNING_SCHEDULE_MAX_STEP]) + 1
self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[
CURRICULUM_LEARNING_SCHEDULE_CONFIG]
elif config[
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_ROOT:
""" """
The schedule_config includes: The schedule_config includes:
total_curriculum_step: how many steps the curriculum learning takes to go total_curriculum_step: how many steps the curriculum learning takes to go
...@@ -79,15 +71,12 @@ class CurriculumScheduler(object): ...@@ -79,15 +71,12 @@ class CurriculumScheduler(object):
f"Curriculum learning with fixed_root schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP}'" f"Curriculum learning with fixed_root schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP}'"
assert CURRICULUM_LEARNING_SCHEDULE_ROOT_DEGREE in config[CURRICULUM_LEARNING_SCHEDULE_CONFIG], \ assert CURRICULUM_LEARNING_SCHEDULE_ROOT_DEGREE in config[CURRICULUM_LEARNING_SCHEDULE_CONFIG], \
f"Curriculum learning with fixed_root schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_ROOT_DEGREE}'" f"Curriculum learning with fixed_root schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_ROOT_DEGREE}'"
if config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][ if config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP] % 8 != 0:
CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP] % 8 != 0:
logger.warning( logger.warning(
f'When using seqlen metric, the difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your metric/hardware.' f'When using seqlen metric, the difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your metric/hardware.'
) )
self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[ self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
CURRICULUM_LEARNING_SCHEDULE_CONFIG] elif config[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_LINEAR:
elif config[
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_LINEAR:
""" """
The schedule_config is the same as CURRICULUM_LEARNING_SCHEDULE_FIXED_ROOT but without the The schedule_config is the same as CURRICULUM_LEARNING_SCHEDULE_FIXED_ROOT but without the
root_degree. root_degree.
...@@ -100,15 +89,12 @@ class CurriculumScheduler(object): ...@@ -100,15 +89,12 @@ class CurriculumScheduler(object):
f"Curriculum learning with fixed_linear schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_TOTAL_STEP}'" f"Curriculum learning with fixed_linear schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_TOTAL_STEP}'"
assert CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP in config[CURRICULUM_LEARNING_SCHEDULE_CONFIG], \ assert CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP in config[CURRICULUM_LEARNING_SCHEDULE_CONFIG], \
f"Curriculum learning with fixed_linear schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP}'" f"Curriculum learning with fixed_linear schedule requires the schedule_config '{CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP}'"
if config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][ if config[CURRICULUM_LEARNING_SCHEDULE_CONFIG][CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP] % 8 != 0:
CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP] % 8 != 0:
logger.warning( logger.warning(
f'When using seqlen metric, the difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your metric/hardware.' f'When using seqlen metric, the difficulty_step for curriculum learning has to be multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. Disregard this warning if this is unrelated to your metric/hardware.'
) )
self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[ self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] = config[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
CURRICULUM_LEARNING_SCHEDULE_CONFIG] elif config[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_CUSTOM:
elif config[
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_CUSTOM:
""" """
Fully customized schedule. User need to provide a custom schedule Fully customized schedule. User need to provide a custom schedule
function by using the set_custom_curriculum_learning_schedule API function by using the set_custom_curriculum_learning_schedule API
...@@ -145,38 +131,28 @@ class CurriculumScheduler(object): ...@@ -145,38 +131,28 @@ class CurriculumScheduler(object):
s_state = self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG] s_state = self.state[CURRICULUM_LEARNING_SCHEDULE_CONFIG]
if root_degree is None: if root_degree is None:
root_degree = s_state[CURRICULUM_LEARNING_SCHEDULE_ROOT_DEGREE] root_degree = s_state[CURRICULUM_LEARNING_SCHEDULE_ROOT_DEGREE]
next_difficulty = (float(global_steps) / next_difficulty = (float(global_steps) / s_state[CURRICULUM_LEARNING_SCHEDULE_TOTAL_STEP])**(1.0 / root_degree)
s_state[CURRICULUM_LEARNING_SCHEDULE_TOTAL_STEP])**( next_difficulty = math.floor(
1.0 / root_degree) next_difficulty *
next_difficulty = math.floor(next_difficulty * (self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY] - self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY]) +
(self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY] - self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY])
self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY]) + next_difficulty -= (next_difficulty % s_state[CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP])
self.state[CURRICULUM_LEARNING_MIN_DIFFICULTY]) next_difficulty = min(next_difficulty, self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY])
next_difficulty -= (next_difficulty %
s_state[CURRICULUM_LEARNING_SCHEDULE_DIFFICULTY_STEP])
next_difficulty = min(next_difficulty,
self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY])
return next_difficulty return next_difficulty
def get_difficulty(self, global_steps): def get_difficulty(self, global_steps):
if self.state[ if self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_DISCRETE:
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_DISCRETE:
return self.__fixed_discrete_get_difficulty(global_steps) return self.__fixed_discrete_get_difficulty(global_steps)
elif self.state[ elif self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_LINEAR:
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_LINEAR:
return self.__fixed_root_get_difficulty(global_steps, 1) return self.__fixed_root_get_difficulty(global_steps, 1)
elif self.state[ elif self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_ROOT:
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_FIXED_ROOT:
return self.__fixed_root_get_difficulty(global_steps) return self.__fixed_root_get_difficulty(global_steps)
elif self.state[ elif self.state[CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_CUSTOM:
CURRICULUM_LEARNING_SCHEDULE_TYPE] == CURRICULUM_LEARNING_SCHEDULE_CUSTOM:
return self.custom_get_difficulty(global_steps) return self.custom_get_difficulty(global_steps)
else: else:
raise RuntimeError('Unsupported curriculum schedule type') raise RuntimeError('Unsupported curriculum schedule type')
def update_difficulty(self, global_steps): def update_difficulty(self, global_steps):
if self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] < self.state[ if self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] < self.state[CURRICULUM_LEARNING_MAX_DIFFICULTY]:
CURRICULUM_LEARNING_MAX_DIFFICULTY]: self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] = self.get_difficulty(global_steps)
self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] = self.get_difficulty(
global_steps)
return self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY] return self.state[CURRICULUM_LEARNING_CURRENT_DIFFICULTY]
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team''' '''Copyright The Microsoft DeepSpeed Team'''
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
from deepspeed.utils import logger from deepspeed.utils import logger
from torch import Tensor from torch import Tensor
...@@ -14,6 +15,7 @@ class RandomLayerTokenDrop(Module): ...@@ -14,6 +15,7 @@ class RandomLayerTokenDrop(Module):
""" """
A layer wrapper for random LTD A layer wrapper for random LTD
""" """
def __init__(self, layer: Module): def __init__(self, layer: Module):
super(RandomLayerTokenDrop, self).__init__() super(RandomLayerTokenDrop, self).__init__()
self.random_ltd_layer = layer self.random_ltd_layer = layer
...@@ -52,9 +54,7 @@ class RandomLayerTokenDrop(Module): ...@@ -52,9 +54,7 @@ class RandomLayerTokenDrop(Module):
elif self.model_type == 'decoder': elif self.model_type == 'decoder':
self.index_generator = gpt_sample_tokens self.index_generator = gpt_sample_tokens
else: else:
logger.warning( logger.warning("************For now, we only support encoder-only or decoder-only models************")
"************For now, we only support encoder-only or decoder-only models************"
)
raise NotImplementedError raise NotImplementedError
def get_bsh(self, hidden_stats): def get_bsh(self, hidden_stats):
...@@ -78,40 +78,36 @@ class RandomLayerTokenDrop(Module): ...@@ -78,40 +78,36 @@ class RandomLayerTokenDrop(Module):
self.curr_micro_batch, \ self.curr_micro_batch, \
self.random_ltd_num_layer, \ self.random_ltd_num_layer, \
hidden_states.device, mask) hidden_states.device, mask)
self.random_ltd_scheduler.state[ self.random_ltd_scheduler.state[RANDOM_LTD_SAMPLE_INDEX] = sampled_indices
RANDOM_LTD_SAMPLE_INDEX] = sampled_indices self.random_ltd_scheduler.state[RANDOM_LTD_ATTENTION_MASK] = part_attention_mask
self.random_ltd_scheduler.state[
RANDOM_LTD_ATTENTION_MASK] = part_attention_mask
else: else:
sampled_indices = self.random_ltd_scheduler.state[ sampled_indices = self.random_ltd_scheduler.state[RANDOM_LTD_SAMPLE_INDEX]
RANDOM_LTD_SAMPLE_INDEX] part_attention_mask = self.random_ltd_scheduler.state[RANDOM_LTD_ATTENTION_MASK]
part_attention_mask = self.random_ltd_scheduler.state[
RANDOM_LTD_ATTENTION_MASK]
hidden_states, part_hidden_states = GatherTokens.apply(hidden_states, sampled_indices[self.random_ltd_layer_id,:,:], self.batch_first) hidden_states, part_hidden_states = GatherTokens.apply(hidden_states,
sampled_indices[self.random_ltd_layer_id, :, :],
self.batch_first)
if self.mask_name is not None: if self.mask_name is not None:
if self.model_type == 'encoder': if self.model_type == 'encoder':
kwargs[self.mask_name] = part_attention_mask[ kwargs[self.mask_name] = part_attention_mask[self.random_ltd_layer_id]
self.random_ltd_layer_id]
else: else:
kwargs[self.mask_name] = part_attention_mask kwargs[self.mask_name] = part_attention_mask
outputs = self.random_ltd_layer(part_hidden_states, **kwargs) outputs = self.random_ltd_layer(part_hidden_states, **kwargs)
if isinstance(outputs, tuple): if isinstance(outputs, tuple):
hidden_states = ScatterTokens.apply(hidden_states, outputs[0], sampled_indices[self.random_ltd_layer_id,:,:], self.batch_first) hidden_states = ScatterTokens.apply(hidden_states, outputs[0],
sampled_indices[self.random_ltd_layer_id, :, :], self.batch_first)
my_list = list(outputs) my_list = list(outputs)
my_list[0] = hidden_states my_list[0] = hidden_states
return tuple(my_list) return tuple(my_list)
elif isinstance(outputs, Tensor): elif isinstance(outputs, Tensor):
hidden_states = ScatterTokens.apply(hidden_states, outputs, sampled_indices[self.random_ltd_layer_id,:,:], self.batch_first) hidden_states = ScatterTokens.apply(hidden_states, outputs,
sampled_indices[self.random_ltd_layer_id, :, :], self.batch_first)
return hidden_states return hidden_states
else: else:
logger.warning( logger.warning("************For now, we only support tuple and tensor output. \
"************For now, we only support tuple and tensor output. \ You need to adjust the output according to the layer in your model************")
You need to adjust the output according to the layer in your model************"
)
raise NotImplementedError raise NotImplementedError
else: else:
return self.random_ltd_layer(hidden_states, **kwargs) return self.random_ltd_layer(hidden_states, **kwargs)
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
from .basic_layer import RandomLayerTokenDrop from .basic_layer import RandomLayerTokenDrop
from collections import OrderedDict from collections import OrderedDict
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import math import math
...@@ -12,6 +13,7 @@ from ..constants import * ...@@ -12,6 +13,7 @@ from ..constants import *
class BaseScheduler(object): class BaseScheduler(object):
def __init__(self): def __init__(self):
self.state = {} self.state = {}
...@@ -19,12 +21,9 @@ class BaseScheduler(object): ...@@ -19,12 +21,9 @@ class BaseScheduler(object):
s_state = self.state[RANDOM_LTD_SCHEDULE_CONFIG] s_state = self.state[RANDOM_LTD_SCHEDULE_CONFIG]
if root_degree is None: if root_degree is None:
root_degree = s_state['root_degree'] root_degree = s_state['root_degree']
next_seq = (float(global_steps) / next_seq = (float(global_steps) / s_state[RANDOM_LTD_REQUIRE_STEP])**(1.0 / root_degree)
s_state[RANDOM_LTD_REQUIRE_STEP])**(1.0 / root_degree) next_seq = math.floor(next_seq * (self.state[RANDOM_LTD_MAX_VALUE] - self.state[RANDOM_LTD_MIN_VALUE]) +
next_seq = math.floor( self.state[RANDOM_LTD_MIN_VALUE])
next_seq *
(self.state[RANDOM_LTD_MAX_VALUE] - self.state[RANDOM_LTD_MIN_VALUE]) +
self.state[RANDOM_LTD_MIN_VALUE])
next_seq -= (next_seq % s_state[RANDOM_LTD_INCREASE_STEP]) next_seq -= (next_seq % s_state[RANDOM_LTD_INCREASE_STEP])
next_seq = min(next_seq, self.state[RANDOM_LTD_MAX_VALUE]) next_seq = min(next_seq, self.state[RANDOM_LTD_MAX_VALUE])
return next_seq return next_seq
...@@ -37,6 +36,7 @@ class BaseScheduler(object): ...@@ -37,6 +36,7 @@ class BaseScheduler(object):
class RandomLTDScheduler(BaseScheduler): class RandomLTDScheduler(BaseScheduler):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.model_layer_num = config[RANDOM_LTD_TOTAL_LAYER_NUM] self.model_layer_num = config[RANDOM_LTD_TOTAL_LAYER_NUM]
...@@ -61,12 +61,9 @@ class RandomLTDScheduler(BaseScheduler): ...@@ -61,12 +61,9 @@ class RandomLTDScheduler(BaseScheduler):
if self.config_schedule is not None: if self.config_schedule is not None:
self.state[RANDOM_LTD_MIN_VALUE] = self.config_schedule[RANDOM_LTD_MIN_VALUE] self.state[RANDOM_LTD_MIN_VALUE] = self.config_schedule[RANDOM_LTD_MIN_VALUE]
self.state[RANDOM_LTD_MAX_VALUE] = self.config_schedule[RANDOM_LTD_MAX_VALUE] self.state[RANDOM_LTD_MAX_VALUE] = self.config_schedule[RANDOM_LTD_MAX_VALUE]
self.state[RANDOM_LTD_CURRENT_VALUE] = self.config_schedule[ self.state[RANDOM_LTD_CURRENT_VALUE] = self.config_schedule[RANDOM_LTD_MIN_VALUE]
RANDOM_LTD_MIN_VALUE] self.state[RANDOM_LTD_SCHEDULE_CONFIG] = self.config_schedule[RANDOM_LTD_SCHEDULE_CONFIG]
self.state[RANDOM_LTD_SCHEDULE_CONFIG] = self.config_schedule[ self.state[RANDOM_LTD_SCHEDULER_TYPE] = self.config_schedule[RANDOM_LTD_SCHEDULER_TYPE]
RANDOM_LTD_SCHEDULE_CONFIG]
self.state[RANDOM_LTD_SCHEDULER_TYPE] = self.config_schedule[
RANDOM_LTD_SCHEDULER_TYPE]
self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = 0 self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = 0
self.state[RANDOM_LTD_CURR_STEP] = -1 self.state[RANDOM_LTD_CURR_STEP] = -1
...@@ -95,8 +92,7 @@ class RandomLTDScheduler(BaseScheduler): ...@@ -95,8 +92,7 @@ class RandomLTDScheduler(BaseScheduler):
def state_dict(self): def state_dict(self):
return { return {
RANDOM_LTD_CONSUMED_LAYER_TOKENS: RANDOM_LTD_CONSUMED_LAYER_TOKENS: self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS],
self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS],
RANDOM_LTD_CURR_STEP: self.state[RANDOM_LTD_CURR_STEP], RANDOM_LTD_CURR_STEP: self.state[RANDOM_LTD_CURR_STEP],
RANDOM_LTD_CURRENT_VALUE: self.state[RANDOM_LTD_CURRENT_VALUE], RANDOM_LTD_CURRENT_VALUE: self.state[RANDOM_LTD_CURRENT_VALUE],
RANDOM_LTD_MIN_VALUE: self.state[RANDOM_LTD_MIN_VALUE], RANDOM_LTD_MIN_VALUE: self.state[RANDOM_LTD_MIN_VALUE],
...@@ -104,8 +100,7 @@ class RandomLTDScheduler(BaseScheduler): ...@@ -104,8 +100,7 @@ class RandomLTDScheduler(BaseScheduler):
} }
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = state_dict[ self.state[RANDOM_LTD_CONSUMED_LAYER_TOKENS] = state_dict[RANDOM_LTD_CONSUMED_LAYER_TOKENS]
RANDOM_LTD_CONSUMED_LAYER_TOKENS]
self.state[RANDOM_LTD_CURR_STEP] = state_dict[RANDOM_LTD_CURR_STEP] self.state[RANDOM_LTD_CURR_STEP] = state_dict[RANDOM_LTD_CURR_STEP]
self.state[RANDOM_LTD_CURRENT_VALUE] = state_dict[RANDOM_LTD_CURRENT_VALUE] self.state[RANDOM_LTD_CURRENT_VALUE] = state_dict[RANDOM_LTD_CURRENT_VALUE]
self.state[RANDOM_LTD_MIN_VALUE] = state_dict[RANDOM_LTD_MIN_VALUE] self.state[RANDOM_LTD_MIN_VALUE] = state_dict[RANDOM_LTD_MIN_VALUE]
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import torch import torch
...@@ -10,8 +11,7 @@ def bsh_decoder_gather(reserved_length, hidden_states, mask): ...@@ -10,8 +11,7 @@ def bsh_decoder_gather(reserved_length, hidden_states, mask):
rand_list = [] rand_list = []
part_hidden_states = [] # batch, seq, hidden ## different from megatron part_hidden_states = [] # batch, seq, hidden ## different from megatron
for k in range(hidden_states.size(0)): for k in range(hidden_states.size(0)):
B_tmp = torch.randperm(hidden_states.size(1), B_tmp = torch.randperm(hidden_states.size(1), device=hidden_states.device)[:reserved_length]
device=hidden_states.device)[:reserved_length]
B = B_tmp.sort()[0] B = B_tmp.sort()[0]
rand_list.append(B) rand_list.append(B)
part_hidden_states.append(hidden_states[k:k + 1, B, :]) part_hidden_states.append(hidden_states[k:k + 1, B, :])
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team''' '''Copyright The Microsoft DeepSpeed Team'''
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import os import os
from collections import defaultdict from collections import defaultdict
import csv import csv
...@@ -16,6 +18,7 @@ from .utils import split_dataset, split_index, create_mmap_dataset_builder, clos ...@@ -16,6 +18,7 @@ from .utils import split_dataset, split_index, create_mmap_dataset_builder, clos
class DataAnalyzer(object): class DataAnalyzer(object):
def __init__(self, def __init__(self,
dataset, dataset,
num_workers=1, num_workers=1,
...@@ -53,25 +56,19 @@ class DataAnalyzer(object): ...@@ -53,25 +56,19 @@ class DataAnalyzer(object):
self.custom_map_finalize = custom_map_finalize self.custom_map_finalize = custom_map_finalize
self.custom_reduce = custom_reduce self.custom_reduce = custom_reduce
def init_metric_results(self, def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id):
thread_id,
metric_names,
metric_types,
metric_dtypes,
save_path,
worker_id):
metric_results = [] metric_results = []
for m_idx in range(len(metric_names)): for m_idx in range(len(metric_names)):
metric_name, metric_type, metric_dtype = metric_names[m_idx], \ metric_name, metric_type, metric_dtype = metric_names[m_idx], \
metric_types[m_idx], metric_dtypes[m_idx] metric_types[m_idx], metric_dtypes[m_idx]
assert metric_dtype not in [np.float64, np.double], "Currently floating point metric values are not supported. Please change your metric into integer values (and potentially multiply a larger coefficient to keep the precision)." assert metric_dtype not in [
np.float64, np.double
], "Currently floating point metric values are not supported. Please change your metric into integer values (and potentially multiply a larger coefficient to keep the precision)."
metric_save_path = f"{save_path}/{metric_name}/worker{worker_id}_thread{thread_id}/" metric_save_path = f"{save_path}/{metric_name}/worker{worker_id}_thread{thread_id}/"
os.makedirs(metric_save_path, exist_ok=True) os.makedirs(metric_save_path, exist_ok=True)
if metric_type == 'single_value_per_sample': if metric_type == 'single_value_per_sample':
sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric" sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
sample_to_metric_builder = create_mmap_dataset_builder( sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_dtype)
sample_to_metric_fname,
metric_dtype)
metric_to_sample_fname = f"{metric_save_path}/{metric_name}_metric_to_sample" metric_to_sample_fname = f"{metric_save_path}/{metric_name}_metric_to_sample"
os.system(f"rm -rf {metric_to_sample_fname}*") os.system(f"rm -rf {metric_to_sample_fname}*")
metric_to_sample_dict = defaultdict(list) metric_to_sample_dict = defaultdict(list)
...@@ -84,34 +81,25 @@ class DataAnalyzer(object): ...@@ -84,34 +81,25 @@ class DataAnalyzer(object):
elif metric_type == 'accumulate_value_over_samples': elif metric_type == 'accumulate_value_over_samples':
metric_value = None metric_value = None
metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value" metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
metric_results.append({ metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname})
"metric_value": metric_value,
"metric_value_fname": metric_value_fname
})
return metric_results return metric_results
def update_metric_results(self, def update_metric_results(self, data, metric_types, metric_functions, metric_results):
data,
metric_types,
metric_functions,
metric_results):
for m_idx in range(len(metric_types)): for m_idx in range(len(metric_types)):
metric_type, metric_function, metric_result = metric_types[m_idx], \ metric_type, metric_function, metric_result = metric_types[m_idx], \
metric_functions[m_idx], metric_results[m_idx] metric_functions[m_idx], metric_results[m_idx]
if metric_type == 'single_value_per_sample': if metric_type == 'single_value_per_sample':
metric_values = metric_function(data) metric_values = metric_function(data)
for row in range(metric_values.size()[0]): for row in range(metric_values.size()[0]):
metric_result["sample_to_metric_builder"].add_item( metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1))
metric_values[row].reshape(-1)) metric_result["metric_to_sample_dict"][metric_values[row].item()].append(
metric_result["metric_to_sample_dict"][ data['index'][row][0].item())
metric_values[row].item()].append(data['index'][row][0].item())
for m_value in metric_result["metric_to_sample_dict"]: for m_value in metric_result["metric_to_sample_dict"]:
if len(metric_result["metric_to_sample_dict"][m_value]) > 100: if len(metric_result["metric_to_sample_dict"][m_value]) > 100:
metric_fname = metric_result["metric_to_sample_fname"] metric_fname = metric_result["metric_to_sample_fname"]
with open(f"{metric_fname}_{m_value}.csv", 'a') as f: with open(f"{metric_fname}_{m_value}.csv", 'a') as f:
writer = csv.writer(f) writer = csv.writer(f)
writer.writerows( writer.writerows([metric_result["metric_to_sample_dict"][m_value]])
[metric_result["metric_to_sample_dict"][m_value]])
metric_result["metric_to_sample_dict"][m_value] = [] metric_result["metric_to_sample_dict"][m_value] = []
elif metric_type == 'accumulate_value_over_samples': elif metric_type == 'accumulate_value_over_samples':
metric_values = metric_function(data) metric_values = metric_function(data)
...@@ -126,25 +114,20 @@ class DataAnalyzer(object): ...@@ -126,25 +114,20 @@ class DataAnalyzer(object):
metric_dtypes[m_idx], metric_results[m_idx] metric_dtypes[m_idx], metric_results[m_idx]
if metric_type == 'single_value_per_sample': if metric_type == 'single_value_per_sample':
metric_fname = metric_result["sample_to_metric_fname"] metric_fname = metric_result["sample_to_metric_fname"]
close_mmap_dataset_builder(metric_result["sample_to_metric_builder"], close_mmap_dataset_builder(metric_result["sample_to_metric_builder"], metric_fname)
metric_fname)
for m_value in metric_result["metric_to_sample_dict"]: for m_value in metric_result["metric_to_sample_dict"]:
if len(metric_result["metric_to_sample_dict"][m_value]) > 0: if len(metric_result["metric_to_sample_dict"][m_value]) > 0:
metric_fname = metric_result["metric_to_sample_fname"] metric_fname = metric_result["metric_to_sample_fname"]
with open(f"{metric_fname}_{m_value}.csv", 'a') as f: with open(f"{metric_fname}_{m_value}.csv", 'a') as f:
writer = csv.writer(f) writer = csv.writer(f)
writer.writerows( writer.writerows([metric_result["metric_to_sample_dict"][m_value]])
[metric_result["metric_to_sample_dict"][m_value]])
metric_result["metric_to_sample_dict"][m_value] = [] metric_result["metric_to_sample_dict"][m_value] = []
elif metric_type == 'accumulate_value_over_samples': elif metric_type == 'accumulate_value_over_samples':
if metric_result["metric_value"] is not None: if metric_result["metric_value"] is not None:
metric_value_builder = create_mmap_dataset_builder( metric_value_builder = create_mmap_dataset_builder(metric_result["metric_value_fname"],
metric_result["metric_value_fname"], metric_dtype)
metric_dtype) metric_value_builder.add_item(metric_result["metric_value"].reshape(-1))
metric_value_builder.add_item( close_mmap_dataset_builder(metric_value_builder, metric_result["metric_value_fname"])
metric_result["metric_value"].reshape(-1))
close_mmap_dataset_builder(metric_value_builder,
metric_result["metric_value_fname"])
def run_map_helper(self, thread_id): def run_map_helper(self, thread_id):
start_idx, end_idx = self.thread_splits[thread_id][0], \ start_idx, end_idx = self.thread_splits[thread_id][0], \
...@@ -152,15 +135,9 @@ class DataAnalyzer(object): ...@@ -152,15 +135,9 @@ class DataAnalyzer(object):
logger.info(f"worker {self.worker_id} thread {thread_id}: start working " \ logger.info(f"worker {self.worker_id} thread {thread_id}: start working " \
f"on data subset {start_idx} to {end_idx}") f"on data subset {start_idx} to {end_idx}")
thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx))) thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx)))
sampler = BatchSampler(SequentialSampler(thread_dataset), sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False)
batch_size=self.batch_size,
drop_last=False)
if self.collate_fn is None: if self.collate_fn is None:
iterator = iter( iterator = iter(DataLoader(thread_dataset, batch_sampler=sampler, num_workers=0, pin_memory=False))
DataLoader(thread_dataset,
batch_sampler=sampler,
num_workers=0,
pin_memory=False))
else: else:
iterator = iter( iterator = iter(
DataLoader(thread_dataset, DataLoader(thread_dataset,
...@@ -169,19 +146,11 @@ class DataAnalyzer(object): ...@@ -169,19 +146,11 @@ class DataAnalyzer(object):
collate_fn=self.collate_fn, collate_fn=self.collate_fn,
pin_memory=False)) pin_memory=False))
if self.custom_map_init is None: if self.custom_map_init is None:
metric_results = self.init_metric_results(thread_id, metric_results = self.init_metric_results(thread_id, self.metric_names, self.metric_types,
self.metric_names, self.metric_dtypes, self.save_path, self.worker_id)
self.metric_types,
self.metric_dtypes,
self.save_path,
self.worker_id)
else: else:
metric_results = self.custom_map_init(thread_id, metric_results = self.custom_map_init(thread_id, self.metric_names, self.metric_types, self.metric_dtypes,
self.metric_names, self.save_path, self.worker_id)
self.metric_types,
self.metric_dtypes,
self.save_path,
self.worker_id)
total_sample = len(thread_dataset) total_sample = len(thread_dataset)
processed_sample = 0 processed_sample = 0
start = time.time() start = time.time()
...@@ -189,15 +158,9 @@ class DataAnalyzer(object): ...@@ -189,15 +158,9 @@ class DataAnalyzer(object):
try: try:
data = next(iterator) data = next(iterator)
if self.custom_map_update is None: if self.custom_map_update is None:
self.update_metric_results(data, self.update_metric_results(data, self.metric_types, self.metric_functions, metric_results)
self.metric_types,
self.metric_functions,
metric_results)
else: else:
self.custom_map_update(data, self.custom_map_update(data, self.metric_types, self.metric_functions, metric_results)
self.metric_types,
self.metric_functions,
metric_results)
processed_sample += self.batch_size processed_sample += self.batch_size
duration = (time.time() - start) / 3600.0 duration = (time.time() - start) / 3600.0
remain_duration = duration * total_sample / processed_sample - duration remain_duration = duration * total_sample / processed_sample - duration
...@@ -206,22 +169,17 @@ class DataAnalyzer(object): ...@@ -206,22 +169,17 @@ class DataAnalyzer(object):
f"out of {total_sample} processed in {duration:.2f} hr, " \ f"out of {total_sample} processed in {duration:.2f} hr, " \
f"estimated to finish in {remain_duration:.2f} hr") f"estimated to finish in {remain_duration:.2f} hr")
except StopIteration: except StopIteration:
logger.info( logger.info(f"worker {self.worker_id} thread {thread_id}: reach end of file")
f"worker {self.worker_id} thread {thread_id}: reach end of file")
break break
if self.custom_map_finalize is None: if self.custom_map_finalize is None:
self.finalize_metric_results(self.metric_types, self.finalize_metric_results(self.metric_types, self.metric_dtypes, metric_results)
self.metric_dtypes,
metric_results)
else: else:
self.custom_map_finalize(self.metric_types, self.custom_map_finalize(self.metric_types, self.metric_dtypes, metric_results)
self.metric_dtypes,
metric_results)
logger.info(f"worker {self.worker_id} thread {thread_id}: finished") logger.info(f"worker {self.worker_id} thread {thread_id}: finished")
def run_map(self): def run_map(self):
self.worker_splits, self.thread_splits = split_dataset(self.dataset, self.worker_splits, self.thread_splits = split_dataset(self.dataset, self.num_workers, self.worker_id,
self.num_workers, self.worker_id, self.num_threads) self.num_threads)
if len(self.specific_threads) > 0: if len(self.specific_threads) > 0:
threads_to_run = self.specific_threads threads_to_run = self.specific_threads
else: else:
...@@ -238,81 +196,50 @@ class DataAnalyzer(object): ...@@ -238,81 +196,50 @@ class DataAnalyzer(object):
assert self.num_threads == 1 assert self.num_threads == 1
self.run_map_helper(0) self.run_map_helper(0)
def get_metric_value_percentiles(self, def get_metric_value_percentiles(self, metric_name, num_sample_per_value, total_num_samples):
metric_name,
num_sample_per_value,
total_num_samples):
logger.info(f"Checking the value percentiles of metric {metric_name}...") logger.info(f"Checking the value percentiles of metric {metric_name}...")
processed_samples = 0 processed_samples = 0
current_percentile = 5 current_percentile = 5
for key in sorted(num_sample_per_value.keys()): for key in sorted(num_sample_per_value.keys()):
processed_samples += num_sample_per_value[key] processed_samples += num_sample_per_value[key]
if processed_samples >= total_num_samples * current_percentile / 100.0: if processed_samples >= total_num_samples * current_percentile / 100.0:
logger.info( logger.info(f"Metric {metric_name} {current_percentile}th percentile: {key}")
f"Metric {metric_name} {current_percentile}th percentile: {key}")
current_percentile += 5 current_percentile += 5
def merge_gather_map_stats(self, def merge_gather_map_stats(self, num_workers, num_threads, num_threads_reduce, t_idx_reduce, metric_save_path,
num_workers, metric_name, return_dict):
num_threads,
num_threads_reduce,
t_idx_reduce,
metric_save_path,
metric_name,
return_dict):
results = [] results = []
for w_idx in range(num_workers): for w_idx in range(num_workers):
for t_idx in range(num_threads): for t_idx in range(num_threads):
if (w_idx * num_threads + t_idx) % num_threads_reduce == t_idx_reduce: if (w_idx * num_threads + t_idx) % num_threads_reduce == t_idx_reduce:
w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/" w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric" w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric"
w_sample_to_metric = MMapIndexedDataset(w_sample_to_metric_fname, w_sample_to_metric = MMapIndexedDataset(w_sample_to_metric_fname, skip_warmup=True)
skip_warmup=True)
unique_v = list(np.unique(w_sample_to_metric)) unique_v = list(np.unique(w_sample_to_metric))
sample_to_metric_count = len(w_sample_to_metric) sample_to_metric_count = len(w_sample_to_metric)
logger.info( logger.info(f"Finished gathering map stats from worker {w_idx} thread {t_idx}.")
f"Finished gathering map stats from worker {w_idx} thread {t_idx}."
)
results.append([unique_v, sample_to_metric_count]) results.append([unique_v, sample_to_metric_count])
return_dict[t_idx_reduce] = results return_dict[t_idx_reduce] = results
def merge_sample_to_metric(self, def merge_sample_to_metric(self, t_idx_reduce, metric_save_path, metric_name, metric_value_dtype,
t_idx_reduce,
metric_save_path,
metric_name,
metric_value_dtype,
map_worker_thread): map_worker_thread):
sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}" sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}"
sample_to_metric_builder = create_mmap_dataset_builder( sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_value_dtype)
sample_to_metric_fname,
metric_value_dtype)
for w_t in map_worker_thread: for w_t in map_worker_thread:
w_metric_save_path = f"{metric_save_path}/worker{w_t[0]}_thread{w_t[1]}/" w_metric_save_path = f"{metric_save_path}/worker{w_t[0]}_thread{w_t[1]}/"
w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric" w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric"
w_data = MMapIndexedDataset(w_sample_to_metric_fname, skip_warmup=True) w_data = MMapIndexedDataset(w_sample_to_metric_fname, skip_warmup=True)
for row in range(len(w_data)): for row in range(len(w_data)):
sample_to_metric_builder.add_item( sample_to_metric_builder.add_item(torch.tensor(w_data[row].astype(np.int64), dtype=torch.long))
torch.tensor(w_data[row].astype(np.int64), logger.info(f"Finished merge_sample_to_metric from worker {w_t[0]} thread {w_t[1]}.")
dtype=torch.long))
logger.info(
f"Finished merge_sample_to_metric from worker {w_t[0]} thread {w_t[1]}.")
close_mmap_dataset_builder(sample_to_metric_builder, sample_to_metric_fname) close_mmap_dataset_builder(sample_to_metric_builder, sample_to_metric_fname)
def merge_metric_to_sample(self, def merge_metric_to_sample(self, t_idx_reduce, metric_save_path, metric_name, sample_idx_dtype, metric_value_dtype,
t_idx_reduce, unique_metric_values, num_workers, num_threads):
metric_save_path,
metric_name,
sample_idx_dtype,
metric_value_dtype,
unique_metric_values,
num_workers,
num_threads):
index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}" index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}"
index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname, index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname, sample_idx_dtype)
sample_idx_dtype)
index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}" index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}"
index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname, index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname, metric_value_dtype)
metric_value_dtype)
for unique_v in unique_metric_values: for unique_v in unique_metric_values:
samples = [] samples = []
for w_idx in range(num_workers): for w_idx in range(num_workers):
...@@ -330,13 +257,7 @@ class DataAnalyzer(object): ...@@ -330,13 +257,7 @@ class DataAnalyzer(object):
close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname) close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname)
close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname) close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname)
def merge_map_results(self, def merge_map_results(self, dataset, metric_names, metric_types, save_path, num_workers, num_threads,
dataset,
metric_names,
metric_types,
save_path,
num_workers,
num_threads,
num_threads_reduce): num_threads_reduce):
total_num_samples = len(dataset) total_num_samples = len(dataset)
sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1) sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1)
...@@ -385,9 +306,7 @@ class DataAnalyzer(object): ...@@ -385,9 +306,7 @@ class DataAnalyzer(object):
for w_idx in range(num_workers): for w_idx in range(num_workers):
for t_idx in range(num_threads): for t_idx in range(num_threads):
map_worker_thread.append([w_idx, t_idx]) map_worker_thread.append([w_idx, t_idx])
thread_splits = split_index(0, thread_splits = split_index(0, len(map_worker_thread), num_threads_reduce)
len(map_worker_thread),
num_threads_reduce)
p = [] p = []
for t_idx_reduce in range(num_threads_reduce): for t_idx_reduce in range(num_threads_reduce):
start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1] start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1]
...@@ -405,24 +324,18 @@ class DataAnalyzer(object): ...@@ -405,24 +324,18 @@ class DataAnalyzer(object):
p[t_idx_reduce].join() p[t_idx_reduce].join()
sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric" sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
sample_to_metric_builder = create_mmap_dataset_builder( sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_value_dtype)
sample_to_metric_fname,
metric_value_dtype)
for t_idx_reduce in range(num_threads_reduce): for t_idx_reduce in range(num_threads_reduce):
chunk_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}" chunk_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}"
logger.info(f"Merging file {chunk_fname}") logger.info(f"Merging file {chunk_fname}")
sample_to_metric_builder.merge_file_(chunk_fname) sample_to_metric_builder.merge_file_(chunk_fname)
close_mmap_dataset_builder(sample_to_metric_builder, close_mmap_dataset_builder(sample_to_metric_builder, sample_to_metric_fname)
sample_to_metric_fname) sample_to_metric = MMapIndexedDataset(sample_to_metric_fname, skip_warmup=True)
sample_to_metric = MMapIndexedDataset(sample_to_metric_fname,
skip_warmup=True)
assert len(sample_to_metric) == total_num_samples assert len(sample_to_metric) == total_num_samples
# metric_to_sample # metric_to_sample
unique_metric_values = list(sorted(unique_metric_values)) unique_metric_values = list(sorted(unique_metric_values))
thread_splits = split_index(0, thread_splits = split_index(0, len(unique_metric_values), num_threads_reduce)
len(unique_metric_values),
num_threads_reduce)
p = [] p = []
for t_idx_reduce in range(num_threads_reduce): for t_idx_reduce in range(num_threads_reduce):
start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1] start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1]
...@@ -442,13 +355,9 @@ class DataAnalyzer(object): ...@@ -442,13 +355,9 @@ class DataAnalyzer(object):
for t_idx_reduce in range(num_threads_reduce): for t_idx_reduce in range(num_threads_reduce):
p[t_idx_reduce].join() p[t_idx_reduce].join()
index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample" index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample"
index_to_sample_builder = create_mmap_dataset_builder( index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname, sample_idx_dtype)
index_to_sample_fname,
sample_idx_dtype)
index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric" index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric"
index_to_metric_builder = create_mmap_dataset_builder( index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname, metric_value_dtype)
index_to_metric_fname,
metric_value_dtype)
for t_idx_reduce in range(num_threads_reduce): for t_idx_reduce in range(num_threads_reduce):
chunk_is_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}" chunk_is_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}"
logger.info(f"Merging file {chunk_is_fname}") logger.info(f"Merging file {chunk_is_fname}")
...@@ -456,43 +365,29 @@ class DataAnalyzer(object): ...@@ -456,43 +365,29 @@ class DataAnalyzer(object):
chunk_im_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}" chunk_im_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}"
logger.info(f"Merging file {chunk_im_fname}") logger.info(f"Merging file {chunk_im_fname}")
index_to_metric_builder.merge_file_(chunk_im_fname) index_to_metric_builder.merge_file_(chunk_im_fname)
close_mmap_dataset_builder(index_to_sample_builder, close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname)
index_to_sample_fname) close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname)
close_mmap_dataset_builder(index_to_metric_builder,
index_to_metric_fname)
num_sample_per_value = {} num_sample_per_value = {}
index_to_sample = MMapIndexedDataset(index_to_sample_fname, index_to_sample = MMapIndexedDataset(index_to_sample_fname, skip_warmup=True)
skip_warmup=True) index_to_metric = MMapIndexedDataset(index_to_metric_fname, skip_warmup=True)
index_to_metric = MMapIndexedDataset(index_to_metric_fname,
skip_warmup=True)
index_to_sample_merged_fname = f"{metric_save_path}/{metric_name}_index_to_sample_percentile_merged" index_to_sample_merged_fname = f"{metric_save_path}/{metric_name}_index_to_sample_percentile_merged"
index_to_sample_merged_builder = create_mmap_dataset_builder( index_to_sample_merged_builder = create_mmap_dataset_builder(index_to_sample_merged_fname,
index_to_sample_merged_fname, sample_idx_dtype)
sample_idx_dtype)
for v_idx in range(len(index_to_sample)): for v_idx in range(len(index_to_sample)):
if v_idx > 0: if v_idx > 0:
assert index_to_metric[v_idx] > index_to_metric[v_idx - 1] assert index_to_metric[v_idx] > index_to_metric[v_idx - 1]
num_sample_per_value[index_to_metric[v_idx][0]] = len( num_sample_per_value[index_to_metric[v_idx][0]] = len(index_to_sample[v_idx])
index_to_sample[v_idx])
assert sum(num_sample_per_value.values()) == total_num_samples assert sum(num_sample_per_value.values()) == total_num_samples
merge_step = len(index_to_sample) // 100 merge_step = len(index_to_sample) // 100
for v_idx in range(0, len(index_to_sample), merge_step): for v_idx in range(0, len(index_to_sample), merge_step):
merged_samples = np.copy( merged_samples = np.copy(
np.concatenate( np.concatenate(index_to_sample[v_idx:min(len(index_to_sample), (v_idx + merge_step))],
index_to_sample[v_idx:min(len(index_to_sample), axis=None))
(v_idx + merge_step))],
axis=None))
index_to_sample_merged_builder.add_item( index_to_sample_merged_builder.add_item(
torch.tensor(merged_samples.astype(np.int64), torch.tensor(merged_samples.astype(np.int64), dtype=torch.long))
dtype=torch.long)) logger.info(f"Finished merging index_to_sample {v_idx} to {v_idx+merge_step}.")
logger.info( close_mmap_dataset_builder(index_to_sample_merged_builder, index_to_sample_merged_fname)
f"Finished merging index_to_sample {v_idx} to {v_idx+merge_step}." self.get_metric_value_percentiles(metric_name, num_sample_per_value, total_num_samples)
)
close_mmap_dataset_builder(index_to_sample_merged_builder,
index_to_sample_merged_fname)
self.get_metric_value_percentiles(metric_name,
num_sample_per_value,
total_num_samples)
elif metric_type == 'accumulate_value_over_samples': elif metric_type == 'accumulate_value_over_samples':
metric_save_path = f"{save_path}/{metric_name}/" metric_save_path = f"{save_path}/{metric_name}/"
metric_value = None metric_value = None
...@@ -500,8 +395,7 @@ class DataAnalyzer(object): ...@@ -500,8 +395,7 @@ class DataAnalyzer(object):
for t_idx in range(num_threads): for t_idx in range(num_threads):
w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/" w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
w_metric_value_fname = f"{w_metric_save_path}/{metric_name}_metric_value" w_metric_value_fname = f"{w_metric_save_path}/{metric_name}_metric_value"
w_metric_value = MMapIndexedDataset(w_metric_value_fname, w_metric_value = MMapIndexedDataset(w_metric_value_fname, skip_warmup=True)
skip_warmup=True)
if metric_value is None: if metric_value is None:
metric_value = np.copy(w_metric_value[0]) metric_value = np.copy(w_metric_value[0])
else: else:
...@@ -510,28 +404,14 @@ class DataAnalyzer(object): ...@@ -510,28 +404,14 @@ class DataAnalyzer(object):
value_min = int(min(metric_value)) value_min = int(min(metric_value))
metric_value_dtype = find_fit_int_dtype(value_min, value_max) metric_value_dtype = find_fit_int_dtype(value_min, value_max)
metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value" metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
metric_value_builder = create_mmap_dataset_builder( metric_value_builder = create_mmap_dataset_builder(metric_value_fname, metric_value_dtype)
metric_value_fname, metric_value_builder.add_item(torch.tensor(metric_value.astype(np.int64), dtype=torch.long))
metric_value_dtype)
metric_value_builder.add_item(
torch.tensor(metric_value.astype(np.int64),
dtype=torch.long))
close_mmap_dataset_builder(metric_value_builder, metric_value_fname) close_mmap_dataset_builder(metric_value_builder, metric_value_fname)
def run_reduce(self): def run_reduce(self):
if self.custom_reduce is None: if self.custom_reduce is None:
self.merge_map_results(self.dataset, self.merge_map_results(self.dataset, self.metric_names, self.metric_types, self.save_path,
self.metric_names, self.num_workers, self.num_threads, self.num_threads_reduce)
self.metric_types,
self.save_path,
self.num_workers,
self.num_threads,
self.num_threads_reduce)
else: else:
self.custom_reduce(self.dataset, self.custom_reduce(self.dataset, self.metric_names, self.metric_types, self.save_path, self.num_workers,
self.metric_names, self.num_threads, self.num_threads_reduce)
self.metric_types,
self.save_path,
self.num_workers,
self.num_threads,
self.num_threads_reduce)
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
coding=utf-8
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Part of this code was adopted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/data_samplers.py Part of this code was adopted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/data_samplers.py
''' """
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import torch
import os import os
...@@ -31,6 +34,7 @@ from .utils import create_mmap_dataset_builder, close_mmap_dataset_builder, find ...@@ -31,6 +34,7 @@ from .utils import create_mmap_dataset_builder, close_mmap_dataset_builder, find
class DeepSpeedDataSampler(object): class DeepSpeedDataSampler(object):
def __init__(self, def __init__(self,
data_efficiency_config, data_efficiency_config,
one_epoch_total_samples, one_epoch_total_samples,
...@@ -45,8 +49,8 @@ class DeepSpeedDataSampler(object): ...@@ -45,8 +49,8 @@ class DeepSpeedDataSampler(object):
self.data_efficiency_config = data_efficiency_config self.data_efficiency_config = data_efficiency_config
self.one_epoch_total_samples = one_epoch_total_samples self.one_epoch_total_samples = one_epoch_total_samples
self.index_dtype = find_fit_int_dtype(0, one_epoch_total_samples) self.index_dtype = find_fit_int_dtype(0, one_epoch_total_samples)
self.total_samples = one_epoch_total_samples * self.data_efficiency_config[ self.total_samples = one_epoch_total_samples * self.data_efficiency_config[DATA_SAMPLING][
DATA_SAMPLING][DATA_SAMPLING_NUM_EPOCHS] DATA_SAMPLING_NUM_EPOCHS]
self.micro_batch_size = micro_batch_size self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank self.data_parallel_rank = data_parallel_rank
self.data_parallel_group = data_parallel_group self.data_parallel_group = data_parallel_group
...@@ -57,13 +61,11 @@ class DeepSpeedDataSampler(object): ...@@ -57,13 +61,11 @@ class DeepSpeedDataSampler(object):
self.gradient_accumulation_steps self.gradient_accumulation_steps
self.global_rank = global_rank self.global_rank = global_rank
self.drop_last = drop_last self.drop_last = drop_last
self.np_rng = np.random.default_rng( self.np_rng = np.random.default_rng(self.data_efficiency_config[DATA_EFFICIENCY_SEED])
self.data_efficiency_config[DATA_EFFICIENCY_SEED])
self.state = {} self.state = {}
self.batch = [] self.batch = []
self.consumed_samples = 0 self.consumed_samples = 0
if self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][ if self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED]:
CURRICULUM_LEARNING_ENABLED]:
self.curriculum_step = 0 self.curriculum_step = 0
self.current_difficulties = {} self.current_difficulties = {}
self.data_cluster_paths = [] self.data_cluster_paths = []
...@@ -77,33 +79,26 @@ class DeepSpeedDataSampler(object): ...@@ -77,33 +79,26 @@ class DeepSpeedDataSampler(object):
if self.global_rank == 0: if self.global_rank == 0:
self.data_clusters = [] self.data_clusters = []
self.data_cluster_sizes = [] self.data_cluster_sizes = []
cluster_path = self.data_efficiency_config[DATA_SAMPLING][ cluster_path = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][
CURRICULUM_LEARNING][CURRICULUM_LEARNING_CLUSTER_PATH] CURRICULUM_LEARNING_CLUSTER_PATH]
if not os.path.exists(cluster_path): if not os.path.exists(cluster_path):
os.makedirs(cluster_path) os.makedirs(cluster_path)
for metric in self.data_efficiency_config[DATA_SAMPLING][ for metric in self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS]:
CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS]:
self.curriculum_schedulers[metric] = CurriculumScheduler( self.curriculum_schedulers[metric] = CurriculumScheduler(
data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING] data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS][metric])
[CURRICULUM_LEARNING_METRICS][metric]) self.difficulty_type[metric] = data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][
self.difficulty_type[metric] = data_efficiency_config[DATA_SAMPLING][ CURRICULUM_LEARNING_METRICS][metric][CURRICULUM_LEARNING_DIFFICULTY_TYPE]
CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS][metric][ self.clustering_type[metric] = data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][
CURRICULUM_LEARNING_DIFFICULTY_TYPE] CURRICULUM_LEARNING_METRICS][metric][CURRICULUM_LEARNING_CLUSTERING_TYPE]
self.clustering_type[metric] = data_efficiency_config[DATA_SAMPLING][
CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS][metric][
CURRICULUM_LEARNING_CLUSTERING_TYPE]
if self.global_rank == 0: if self.global_rank == 0:
if self.clustering_type[metric] != CURRICULUM_LEARNING_SINGLE_CLUSTER: if self.clustering_type[metric] != CURRICULUM_LEARNING_SINGLE_CLUSTER:
self.curriculum_index_to_sample[metric] = MMapIndexedDataset( self.curriculum_index_to_sample[metric] = MMapIndexedDataset(
data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING] data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS]
[CURRICULUM_LEARNING_METRICS][metric] [metric][CURRICULUM_LEARNING_SAMPLE_PATH],
[CURRICULUM_LEARNING_SAMPLE_PATH],
skip_warmup=True) skip_warmup=True)
if self.difficulty_type[ if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED:
metric] == CURRICULUM_LEARNING_VALUE_BASED:
self.curriculum_index_to_metric[metric] = MMapIndexedDataset( self.curriculum_index_to_metric[metric] = MMapIndexedDataset(
data_efficiency_config[DATA_SAMPLING] data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS]
[CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS]
[metric][CURRICULUM_LEARNING_METRIC_PATH], [metric][CURRICULUM_LEARNING_METRIC_PATH],
skip_warmup=True) skip_warmup=True)
...@@ -122,8 +117,7 @@ class DeepSpeedDataSampler(object): ...@@ -122,8 +117,7 @@ class DeepSpeedDataSampler(object):
def set_custom_curriculum_learning_schedule(self, schedule_func_dict): def set_custom_curriculum_learning_schedule(self, schedule_func_dict):
for metric in self.curriculum_schedulers: for metric in self.curriculum_schedulers:
if metric in schedule_func_dict: if metric in schedule_func_dict:
self.curriculum_schedulers[metric].set_custom_get_difficulty( self.curriculum_schedulers[metric].set_custom_get_difficulty(schedule_func_dict[metric])
schedule_func_dict[metric])
def get_start_end_idx(self): def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size start_idx = self.data_parallel_rank * self.micro_batch_size
...@@ -133,26 +127,19 @@ class DeepSpeedDataSampler(object): ...@@ -133,26 +127,19 @@ class DeepSpeedDataSampler(object):
def get_sample_based_on_metric_value(self, metric, value_start, value_end): def get_sample_based_on_metric_value(self, metric, value_start, value_end):
new_samples = None new_samples = None
for row in range(len(self.curriculum_index_to_sample[metric])): for row in range(len(self.curriculum_index_to_sample[metric])):
if self.curriculum_index_to_metric[metric][ if self.curriculum_index_to_metric[metric][row] <= value_end and self.curriculum_index_to_metric[metric][
row] <= value_end and self.curriculum_index_to_metric[metric][ row] > value_start:
row] > value_start:
row_samples = np.copy(self.curriculum_index_to_sample[metric][row]) row_samples = np.copy(self.curriculum_index_to_sample[metric][row])
new_samples = row_samples if new_samples is None else np.concatenate( new_samples = row_samples if new_samples is None else np.concatenate(
(new_samples, (new_samples, row_samples), axis=None)
row_samples),
axis=None)
return new_samples return new_samples
def get_sample_based_on_metric_percentile(self, def get_sample_based_on_metric_percentile(self, metric, percentile_start, percentile_end):
metric,
percentile_start,
percentile_end):
new_samples = None new_samples = None
if self.data_1epoch_size is None: if self.data_1epoch_size is None:
self.data_1epoch_size = sum( self.data_1epoch_size = sum(len(x) for x in self.curriculum_index_to_sample[metric])
len(x) for x in self.curriculum_index_to_sample[metric]) max_percentile = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_METRICS][
max_percentile = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][ metric][CURRICULUM_LEARNING_MAX_DIFFICULTY]
CURRICULUM_LEARNING_METRICS][metric][CURRICULUM_LEARNING_MAX_DIFFICULTY]
sample_per_percentile = self.data_1epoch_size // max_percentile sample_per_percentile = self.data_1epoch_size // max_percentile
start_count = sample_per_percentile * percentile_start start_count = sample_per_percentile * percentile_start
end_count = sample_per_percentile * percentile_end end_count = sample_per_percentile * percentile_end
...@@ -167,12 +154,9 @@ class DeepSpeedDataSampler(object): ...@@ -167,12 +154,9 @@ class DeepSpeedDataSampler(object):
row_end = row_size row_end = row_size
else: else:
row_end = end_count - current_count row_end = end_count - current_count
row_samples = np.copy( row_samples = np.copy(self.curriculum_index_to_sample[metric][row][row_start:row_end])
self.curriculum_index_to_sample[metric][row][row_start:row_end])
new_samples = row_samples if new_samples is None else np.concatenate( new_samples = row_samples if new_samples is None else np.concatenate(
(new_samples, (new_samples, row_samples), axis=None)
row_samples),
axis=None)
current_count += row_size current_count += row_size
if current_count >= end_count: if current_count >= end_count:
break break
...@@ -193,63 +177,42 @@ class DeepSpeedDataSampler(object): ...@@ -193,63 +177,42 @@ class DeepSpeedDataSampler(object):
need_clustering += 1 need_clustering += 1
if need_clustering > 1: if need_clustering > 1:
for metric in self.curriculum_schedulers: for metric in self.curriculum_schedulers:
if self.clustering_type[ if self.clustering_type[metric] == CURRICULUM_LEARNING_SINGLE_CLUSTER:
metric] == CURRICULUM_LEARNING_SINGLE_CLUSTER:
metric_cluster = np.arange(start=0, metric_cluster = np.arange(start=0,
stop=self.one_epoch_total_samples, stop=self.one_epoch_total_samples,
step=1, step=1,
dtype=self.index_dtype) dtype=self.index_dtype)
else: else:
if self.difficulty_type[ if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED:
metric] == CURRICULUM_LEARNING_VALUE_BASED: metric_cluster = self.get_sample_based_on_metric_value(metric, float('-inf'),
metric_cluster = self.get_sample_based_on_metric_value( self.current_difficulties[metric])
metric, elif self.difficulty_type[metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
float('-inf'),
self.current_difficulties[metric])
elif self.difficulty_type[
metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
metric_cluster = self.get_sample_based_on_metric_percentile( metric_cluster = self.get_sample_based_on_metric_percentile(
metric, metric, 0, self.current_difficulties[metric])
0,
self.current_difficulties[metric])
new_cluster = metric_cluster if new_cluster is None else \ new_cluster = metric_cluster if new_cluster is None else \
np.intersect1d(new_cluster, metric_cluster, assume_unique=True) np.intersect1d(new_cluster, metric_cluster, assume_unique=True)
for cluster in self.data_clusters: for cluster in self.data_clusters:
new_cluster = np.setdiff1d(new_cluster, new_cluster = np.setdiff1d(new_cluster, cluster[0], assume_unique=True)
cluster[0],
assume_unique=True)
else: else:
if len(self.data_clusters) == 0: if len(self.data_clusters) == 0:
new_cluster = np.arange(start=0, new_cluster = np.arange(start=0, stop=self.one_epoch_total_samples, step=1, dtype=self.index_dtype)
stop=self.one_epoch_total_samples,
step=1,
dtype=self.index_dtype)
for metric in self.curriculum_schedulers: for metric in self.curriculum_schedulers:
if self.clustering_type[metric] != CURRICULUM_LEARNING_SINGLE_CLUSTER: if self.clustering_type[metric] != CURRICULUM_LEARNING_SINGLE_CLUSTER:
if self.difficulty_type[ if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED:
metric] == CURRICULUM_LEARNING_VALUE_BASED: new_cluster = self.get_sample_based_on_metric_value(metric, previous_difficulties[metric],
new_cluster = self.get_sample_based_on_metric_value( self.current_difficulties[metric])
metric, elif self.difficulty_type[metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
previous_difficulties[metric],
self.current_difficulties[metric])
elif self.difficulty_type[
metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
new_cluster = self.get_sample_based_on_metric_percentile( new_cluster = self.get_sample_based_on_metric_percentile(
metric, metric, previous_difficulties[metric], self.current_difficulties[metric])
previous_difficulties[metric],
self.current_difficulties[metric])
if new_cluster is not None and len(new_cluster) > 0: if new_cluster is not None and len(new_cluster) > 0:
logger.info( logger.info(
f"new data cluster (previous_difficulties {previous_difficulties}, current_difficulties {self.current_difficulties}) with size {len(new_cluster)} generated." f"new data cluster (previous_difficulties {previous_difficulties}, current_difficulties {self.current_difficulties}) with size {len(new_cluster)} generated."
) )
self.np_rng.shuffle(new_cluster) self.np_rng.shuffle(new_cluster)
cluster_builder = create_mmap_dataset_builder(cluster_path, cluster_builder = create_mmap_dataset_builder(cluster_path, self.index_dtype)
self.index_dtype)
cluster_builder.add_item_numpy(new_cluster) cluster_builder.add_item_numpy(new_cluster)
close_mmap_dataset_builder(cluster_builder, cluster_path) close_mmap_dataset_builder(cluster_builder, cluster_path)
self.data_clusters.append( self.data_clusters.append(MMapIndexedDataset(cluster_path, skip_warmup=True))
MMapIndexedDataset(cluster_path,
skip_warmup=True))
self.data_cluster_sizes.append(len(self.data_clusters[-1][0])) self.data_cluster_sizes.append(len(self.data_clusters[-1][0]))
else: else:
logger.info( logger.info(
...@@ -264,10 +227,7 @@ class DeepSpeedDataSampler(object): ...@@ -264,10 +227,7 @@ class DeepSpeedDataSampler(object):
num_clusters = len(self.data_clusters) num_clusters = len(self.data_clusters)
weight_sum = sum(self.data_cluster_sizes) weight_sum = sum(self.data_cluster_sizes)
weights = [x / weight_sum for x in self.data_cluster_sizes] weights = [x / weight_sum for x in self.data_cluster_sizes]
samples = self.np_rng.choice(num_clusters, samples = self.np_rng.choice(num_clusters, self.global_batch_size, replace=True, p=weights)
self.global_batch_size,
replace=True,
p=weights)
samples = np.bincount(samples, minlength=num_clusters) samples = np.bincount(samples, minlength=num_clusters)
return samples return samples
...@@ -285,8 +245,7 @@ class DeepSpeedDataSampler(object): ...@@ -285,8 +245,7 @@ class DeepSpeedDataSampler(object):
def get_sample_from_cluster(self, cidx, num_samples): def get_sample_from_cluster(self, cidx, num_samples):
start_idx = self.data_cluster_current_position[cidx] start_idx = self.data_cluster_current_position[cidx]
samples = list( samples = list(np.copy(self.data_clusters[cidx][0][start_idx:(start_idx + num_samples)]))
np.copy(self.data_clusters[cidx][0][start_idx:(start_idx + num_samples)]))
self.data_cluster_current_position[cidx] += num_samples self.data_cluster_current_position[cidx] += num_samples
if len(samples) < num_samples: if len(samples) < num_samples:
num_samples_remained = num_samples - len(samples) num_samples_remained = num_samples - len(samples)
...@@ -297,14 +256,12 @@ class DeepSpeedDataSampler(object): ...@@ -297,14 +256,12 @@ class DeepSpeedDataSampler(object):
return samples return samples
def get_next_global_batch(self): def get_next_global_batch(self):
if self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][ if self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED]:
CURRICULUM_LEARNING_ENABLED]:
self.curriculum_step += 1 self.curriculum_step += 1
new_cluster = False new_cluster = False
previous_difficulties = {} previous_difficulties = {}
for metric in self.curriculum_schedulers: for metric in self.curriculum_schedulers:
next_difficulty = self.curriculum_schedulers[metric].update_difficulty( next_difficulty = self.curriculum_schedulers[metric].update_difficulty(self.curriculum_step)
self.curriculum_step)
if metric not in self.current_difficulties or \ if metric not in self.current_difficulties or \
next_difficulty != self.current_difficulties[metric]: next_difficulty != self.current_difficulties[metric]:
new_cluster = True new_cluster = True
...@@ -313,8 +270,7 @@ class DeepSpeedDataSampler(object): ...@@ -313,8 +270,7 @@ class DeepSpeedDataSampler(object):
else: else:
if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED: if self.difficulty_type[metric] == CURRICULUM_LEARNING_VALUE_BASED:
previous_difficulties[metric] = float('-inf') previous_difficulties[metric] = float('-inf')
elif self.difficulty_type[ elif self.difficulty_type[metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
metric] == CURRICULUM_LEARNING_PERCENTILE_BASED:
previous_difficulties[metric] = 0 previous_difficulties[metric] = 0
self.current_difficulties[metric] = next_difficulty self.current_difficulties[metric] = next_difficulty
if new_cluster: if new_cluster:
...@@ -323,12 +279,9 @@ class DeepSpeedDataSampler(object): ...@@ -323,12 +279,9 @@ class DeepSpeedDataSampler(object):
samples_per_cluster = self.sample_from_clusters() samples_per_cluster = self.sample_from_clusters()
batch = [] batch = []
for cidx in range(len(samples_per_cluster)): for cidx in range(len(samples_per_cluster)):
batch += self.get_sample_from_cluster(cidx, batch += self.get_sample_from_cluster(cidx, samples_per_cluster[cidx])
samples_per_cluster[cidx])
self.np_rng.shuffle(batch) self.np_rng.shuffle(batch)
batch = torch.tensor(batch, batch = torch.tensor(batch, device=get_accelerator().current_device_name(), dtype=torch.long).view(-1)
device=get_accelerator().current_device_name(),
dtype=torch.long).view(-1)
else: else:
batch = torch.empty(self.global_batch_size, batch = torch.empty(self.global_batch_size,
device=get_accelerator().current_device_name(), device=get_accelerator().current_device_name(),
...@@ -356,8 +309,7 @@ class DeepSpeedDataSampler(object): ...@@ -356,8 +309,7 @@ class DeepSpeedDataSampler(object):
CURRICULUM_LEARNING_STEP: self.curriculum_step, CURRICULUM_LEARNING_STEP: self.curriculum_step,
CURRICULUM_LEARNING_CURRENT_DIFFICULTIES: self.current_difficulties, CURRICULUM_LEARNING_CURRENT_DIFFICULTIES: self.current_difficulties,
CURRICULUM_LEARNING_DATA_CLUSTER_PATHS: self.data_cluster_paths, CURRICULUM_LEARNING_DATA_CLUSTER_PATHS: self.data_cluster_paths,
CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION: CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION: self.data_cluster_current_position,
self.data_cluster_current_position,
CURRICULUM_LEARNING_NP_RNG_STATE: np.random.get_state() CURRICULUM_LEARNING_NP_RNG_STATE: np.random.get_state()
} }
...@@ -367,11 +319,10 @@ class DeepSpeedDataSampler(object): ...@@ -367,11 +319,10 @@ class DeepSpeedDataSampler(object):
self.curriculum_step = state_dict[CURRICULUM_LEARNING_STEP] self.curriculum_step = state_dict[CURRICULUM_LEARNING_STEP]
self.current_difficulties = state_dict[CURRICULUM_LEARNING_CURRENT_DIFFICULTIES] self.current_difficulties = state_dict[CURRICULUM_LEARNING_CURRENT_DIFFICULTIES]
self.data_cluster_paths = state_dict[CURRICULUM_LEARNING_DATA_CLUSTER_PATHS] self.data_cluster_paths = state_dict[CURRICULUM_LEARNING_DATA_CLUSTER_PATHS]
self.data_cluster_current_position = state_dict[ self.data_cluster_current_position = state_dict[CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION]
CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION]
np.random.set_state(state_dict[CURRICULUM_LEARNING_NP_RNG_STATE]) np.random.set_state(state_dict[CURRICULUM_LEARNING_NP_RNG_STATE])
cluster_root_path = self.data_efficiency_config[DATA_SAMPLING][ cluster_root_path = self.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][
CURRICULUM_LEARNING][CURRICULUM_LEARNING_CLUSTER_PATH] CURRICULUM_LEARNING_CLUSTER_PATH]
# Backward compatibility: previously data_cluster_paths were stored as # Backward compatibility: previously data_cluster_paths were stored as
# absolute paths. Now we changed it to just the file name so that even # absolute paths. Now we changed it to just the file name so that even
# if user moved the cluster files, the checkpoint loading still works # if user moved the cluster files, the checkpoint loading still works
...@@ -379,12 +330,9 @@ class DeepSpeedDataSampler(object): ...@@ -379,12 +330,9 @@ class DeepSpeedDataSampler(object):
# in deepspeed json config. # in deepspeed json config.
for idx in range(len(self.data_cluster_paths)): for idx in range(len(self.data_cluster_paths)):
if '/' in self.data_cluster_paths[idx]: if '/' in self.data_cluster_paths[idx]:
self.data_cluster_paths[idx] = self.data_cluster_paths[idx].split( self.data_cluster_paths[idx] = self.data_cluster_paths[idx].split('/')[-1]
'/')[-1]
if self.global_rank == 0: if self.global_rank == 0:
for cluster_fname in self.data_cluster_paths: for cluster_fname in self.data_cluster_paths:
cluster_path = f"{cluster_root_path}/{cluster_fname}" cluster_path = f"{cluster_root_path}/{cluster_fname}"
self.data_clusters.append( self.data_clusters.append(MMapIndexedDataset(cluster_path, skip_warmup=True))
MMapIndexedDataset(cluster_path,
skip_warmup=True))
self.data_cluster_sizes.append(len(self.data_clusters[-1][0])) self.data_cluster_sizes.append(len(self.data_clusters[-1][0]))
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Part of this code was adopted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py Part of this code was adopted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
''' """
# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Facebook, Inc. and its affiliates.
# #
...@@ -50,16 +53,13 @@ def infer_dataset_impl(path): ...@@ -50,16 +53,13 @@ def infer_dataset_impl(path):
return None return None
else: else:
print(f"Dataset does not exist: {path}") print(f"Dataset does not exist: {path}")
print( print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
)
return None return None
def make_builder(out_file, impl, vocab_size=None): def make_builder(out_file, impl, vocab_size=None):
if impl == 'mmap': if impl == 'mmap':
return MMapIndexedDatasetBuilder(out_file, return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
dtype=__best_fitting_dtype(vocab_size))
else: else:
return IndexedDatasetBuilder(out_file) return IndexedDatasetBuilder(out_file)
...@@ -67,9 +67,7 @@ def make_builder(out_file, impl, vocab_size=None): ...@@ -67,9 +67,7 @@ def make_builder(out_file, impl, vocab_size=None):
def make_dataset(path, impl, skip_warmup=False): def make_dataset(path, impl, skip_warmup=False):
if not IndexedDataset.exists(path): if not IndexedDataset.exists(path):
print(f"Dataset does not exist: {path}") print(f"Dataset does not exist: {path}")
print( print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
)
return None return None
if impl == 'infer': if impl == 'infer':
impl = infer_dataset_impl(path) impl = infer_dataset_impl(path)
...@@ -150,10 +148,8 @@ class IndexedDataset(torch.utils.data.Dataset): ...@@ -150,10 +148,8 @@ class IndexedDataset(torch.utils.data.Dataset):
def read_index(self, path): def read_index(self, path):
with open(index_file_path(path), 'rb') as f: with open(index_file_path(path), 'rb') as f:
magic = f.read(8) magic = f.read(8)
assert magic == self._HDR_MAGIC, ( assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. '
'Index file doesn\'t match expected format. ' 'Make sure that --dataset-impl is configured properly.')
'Make sure that --dataset-impl is configured properly.'
)
version = f.read(8) version = f.read(8)
assert struct.unpack('<Q', version) == (1, ) assert struct.unpack('<Q', version) == (1, )
code, self.element_size = struct.unpack('<QQ', f.read(16)) code, self.element_size = struct.unpack('<QQ', f.read(16))
...@@ -212,8 +208,7 @@ class IndexedDataset(torch.utils.data.Dataset): ...@@ -212,8 +208,7 @@ class IndexedDataset(torch.utils.data.Dataset):
@staticmethod @staticmethod
def exists(path): def exists(path):
return (os.path.exists(index_file_path(path)) return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
and os.path.exists(data_file_path(path)))
@property @property
def supports_prefetch(self): def supports_prefetch(self):
...@@ -221,6 +216,7 @@ class IndexedDataset(torch.utils.data.Dataset): ...@@ -221,6 +216,7 @@ class IndexedDataset(torch.utils.data.Dataset):
class IndexedCachedDataset(IndexedDataset): class IndexedCachedDataset(IndexedDataset):
def __init__(self, path): def __init__(self, path):
super().__init__(path) super().__init__(path)
self.cache = None self.cache = None
...@@ -273,15 +269,7 @@ class IndexedCachedDataset(IndexedDataset): ...@@ -273,15 +269,7 @@ class IndexedCachedDataset(IndexedDataset):
class IndexedDatasetBuilder(object): class IndexedDatasetBuilder(object):
element_sizes = { element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, np.float64: 4, np.double: 8}
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float64: 4,
np.double: 8
}
def __init__(self, out_file, dtype=np.int32): def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb') self.out_file = open(out_file, 'wb')
...@@ -379,12 +367,15 @@ def get_pointers_with_total(sizes, elemsize, dtype): ...@@ -379,12 +367,15 @@ def get_pointers_with_total(sizes, elemsize, dtype):
class MMapIndexedDataset(torch.utils.data.Dataset): class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object): class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00' _HDR_MAGIC = b'MMIDIDX\x00\x00'
@classmethod @classmethod
def writer(cls, path, dtype): def writer(cls, path, dtype):
class _Writer(object): class _Writer(object):
def __enter__(self): def __enter__(self):
self._file = open(path, 'wb') self._file = open(path, 'wb')
...@@ -430,10 +421,8 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -430,10 +421,8 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def __init__(self, path, skip_warmup=False): def __init__(self, path, skip_warmup=False):
with open(path, 'rb') as stream: with open(path, 'rb') as stream:
magic_test = stream.read(9) magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, ( assert self._HDR_MAGIC == magic_test, ('Index file doesn\'t match expected format. '
'Index file doesn\'t match expected format. ' 'Make sure that --dataset-impl is configured properly.')
'Make sure that --dataset-impl is configured properly.'
)
version = struct.unpack('<Q', stream.read(8)) version = struct.unpack('<Q', stream.read(8))
assert (1, ) == version assert (1, ) == version
...@@ -452,10 +441,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -452,10 +441,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C') self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap) self._bin_buffer = memoryview(self._bin_buffer_mmap)
print(" reading sizes...") print(" reading sizes...")
self._sizes = np.frombuffer(self._bin_buffer, self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
dtype=np.int32,
count=self._len,
offset=offset)
print(" reading pointers...") print(" reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer, self._pointers = np.frombuffer(self._bin_buffer,
dtype=np.int64, dtype=np.int64,
...@@ -465,8 +451,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -465,8 +451,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
self._doc_idx = np.frombuffer(self._bin_buffer, self._doc_idx = np.frombuffer(self._bin_buffer,
dtype=np.int64, dtype=np.int64,
count=self._doc_count, count=self._doc_count,
offset=offset + self._sizes.nbytes + offset=offset + self._sizes.nbytes + self._pointers.nbytes)
self._pointers.nbytes)
def __del__(self): def __del__(self):
self._bin_buffer_mmap._mmap.close() self._bin_buffer_mmap._mmap.close()
...@@ -514,9 +499,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -514,9 +499,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
print(" warming up data mmap file...") print(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path)) _warmup_mmap_file(data_file_path(self._path))
print(" creating numpy buffer of mmap...") print(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
mode='r',
order='C')
print(" creating memory view of numpy buffer...") print(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap) self._bin_buffer = memoryview(self._bin_buffer_mmap)
...@@ -532,10 +515,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -532,10 +515,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
if isinstance(idx, int): if isinstance(idx, int):
ptr, size = self._index[idx] ptr, size = self._index[idx]
np_array = np.frombuffer(self._bin_buffer, np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
dtype=self._index.dtype,
count=size,
offset=ptr)
return np_array return np_array
elif isinstance(idx, slice): elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self)) start, stop, step = idx.indices(len(self))
...@@ -545,10 +525,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -545,10 +525,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
sizes = self._index._sizes[idx] sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes)) offsets = list(accumulate(sizes))
total_size = sum(sizes) total_size = sum(sizes)
np_array = np.frombuffer(self._bin_buffer, np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
dtype=self._index.dtype,
count=total_size,
offset=ptr)
sents = np.split(np_array, offsets[:-1]) sents = np.split(np_array, offsets[:-1])
return sents return sents
...@@ -562,10 +539,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -562,10 +539,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
if length is None: if length is None:
length = size - offset length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(self._bin_buffer, np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr)
dtype=self._index.dtype,
count=length,
offset=ptr)
return np_array return np_array
@property @property
...@@ -591,8 +565,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -591,8 +565,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
@staticmethod @staticmethod
def exists(path): def exists(path):
return (os.path.exists(index_file_path(path)) return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
and os.path.exists(data_file_path(path)))
@property @property
def dtype(self): def dtype(self):
...@@ -600,6 +573,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -600,6 +573,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
class MMapIndexedDatasetBuilder(object): class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype=np.int64): def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, 'wb') self._data_file = open(out_file, 'wb')
self._dtype = dtype self._dtype = dtype
...@@ -626,9 +600,7 @@ class MMapIndexedDatasetBuilder(object): ...@@ -626,9 +600,7 @@ class MMapIndexedDatasetBuilder(object):
assert index.dtype == self._dtype assert index.dtype == self._dtype
total_len = len(index.sizes) + len(self._sizes) total_len = len(index.sizes) + len(self._sizes)
print( print(f" concat {another_file} size={len(index.sizes)} for a total size of {total_len}")
f" concat {another_file} size={len(index.sizes)} for a total size of {total_len}"
)
offset = len(self._sizes) offset = len(self._sizes)
self._sizes.extend(index.sizes) self._sizes.extend(index.sizes)
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import math import math
import numpy as np import numpy as np
...@@ -32,19 +33,14 @@ def find_fit_int_dtype(min_value, max_value): ...@@ -32,19 +33,14 @@ def find_fit_int_dtype(min_value, max_value):
def split_index(start_idx, end_idx, num_partitions): def split_index(start_idx, end_idx, num_partitions):
partition_size = math.ceil((end_idx - start_idx) / num_partitions) partition_size = math.ceil((end_idx - start_idx) / num_partitions)
partitions = [[ partitions = [[start_idx + x * partition_size,
start_idx + x * partition_size, min(end_idx, start_idx + (x + 1) * partition_size)] for x in range(num_partitions)]
min(end_idx,
start_idx + (x + 1) * partition_size)
] for x in range(num_partitions)]
return partitions return partitions
def split_dataset(dataset, num_workers, worker_id, num_threads): def split_dataset(dataset, num_workers, worker_id, num_threads):
worker_splits = split_index(0, len(dataset), num_workers) worker_splits = split_index(0, len(dataset), num_workers)
thread_splits = split_index(worker_splits[worker_id][0], thread_splits = split_index(worker_splits[worker_id][0], worker_splits[worker_id][1], num_threads)
worker_splits[worker_id][1],
num_threads)
return worker_splits, thread_splits return worker_splits, thread_splits
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2019 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
...@@ -14,6 +15,7 @@ from deepspeed.runtime.constants import GRADIENT_ACCUMULATION_STEPS, \ ...@@ -14,6 +15,7 @@ from deepspeed.runtime.constants import GRADIENT_ACCUMULATION_STEPS, \
class RepeatingLoader: class RepeatingLoader:
def __init__(self, loader): def __init__(self, loader):
"""Wraps an iterator to allow for infinite iteration. This is especially useful """Wraps an iterator to allow for infinite iteration. This is especially useful
for DataLoader types that we wish to automatically restart upon completion. for DataLoader types that we wish to automatically restart upon completion.
...@@ -37,6 +39,7 @@ class RepeatingLoader: ...@@ -37,6 +39,7 @@ class RepeatingLoader:
class DeepSpeedDataLoader(object): class DeepSpeedDataLoader(object):
def __init__(self, def __init__(self,
dataset, dataset,
batch_size, batch_size,
...@@ -55,30 +58,26 @@ class DeepSpeedDataLoader(object): ...@@ -55,30 +58,26 @@ class DeepSpeedDataLoader(object):
self.batch_size = batch_size self.batch_size = batch_size
self.curriculum_learning_enabled = False self.curriculum_learning_enabled = False
if CURRICULUM_LEARNING in deepspeed_dataloader_config: if CURRICULUM_LEARNING in deepspeed_dataloader_config:
self.curriculum_learning_enabled = deepspeed_dataloader_config[ self.curriculum_learning_enabled = deepspeed_dataloader_config[CURRICULUM_LEARNING]
CURRICULUM_LEARNING]
if self.curriculum_learning_enabled: if self.curriculum_learning_enabled:
data_sampler = DeepSpeedDataSampler( data_sampler = DeepSpeedDataSampler(self.deepspeed_dataloader_config[DATA_EFFICIENCY],
self.deepspeed_dataloader_config[DATA_EFFICIENCY], len(dataset),
len(dataset), self.batch_size,
self.batch_size, data_parallel_rank,
data_parallel_rank, data_parallel_world_size,
data_parallel_world_size, self.deepspeed_dataloader_config[DATA_PARALLEL_GROUP],
self.deepspeed_dataloader_config[DATA_PARALLEL_GROUP], self.deepspeed_dataloader_config[GRADIENT_ACCUMULATION_STEPS],
self.deepspeed_dataloader_config[GRADIENT_ACCUMULATION_STEPS], self.deepspeed_dataloader_config[GLOBAL_RANK],
self.deepspeed_dataloader_config[GLOBAL_RANK], drop_last=dataloader_drop_last)
drop_last=dataloader_drop_last)
device_count = get_accelerator().device_count() device_count = get_accelerator().device_count()
num_local_io_workers = self.deepspeed_dataloader_config[ num_local_io_workers = self.deepspeed_dataloader_config[DATA_SAMPLING_NUM_WORKERS]
DATA_SAMPLING_NUM_WORKERS]
else: else:
if local_rank >= 0: if local_rank >= 0:
if data_sampler is None: if data_sampler is None:
data_sampler = DistributedSampler( data_sampler = DistributedSampler(dataset=dataset,
dataset=dataset, num_replicas=data_parallel_world_size,
num_replicas=data_parallel_world_size, rank=data_parallel_rank)
rank=data_parallel_rank)
device_count = 1 device_count = 1
else: else:
if data_sampler is None: if data_sampler is None:
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch import torch
from deepspeed.utils import log_dist from deepspeed.utils import log_dist
...@@ -7,6 +10,7 @@ import logging ...@@ -7,6 +10,7 @@ import logging
class Eigenvalue(object): class Eigenvalue(object):
def __init__(self, def __init__(self,
verbose=False, verbose=False,
max_iter=100, max_iter=100,
...@@ -77,8 +81,7 @@ class Eigenvalue(object): ...@@ -77,8 +81,7 @@ class Eigenvalue(object):
] ]
else: else:
v = [ v = [
torch.randn(p.size(), torch.randn(p.size(), device=device) for p in model_block.parameters()
device=device) for p in model_block.parameters()
if p.grad is not None and p.grad.grad_fn is not None if p.grad is not None and p.grad.grad_fn is not None
] ]
torch.random.set_rng_state(rng_state) torch.random.set_rng_state(rng_state)
...@@ -100,24 +103,18 @@ class Eigenvalue(object): ...@@ -100,24 +103,18 @@ class Eigenvalue(object):
# Disable eigenvalue if the model doesn't support second order gradients computation, # Disable eigenvalue if the model doesn't support second order gradients computation,
# e.g. when enabling DS transformer kernel. # e.g. when enabling DS transformer kernel.
if len(grads) == 0 or len(params) == 0: if len(grads) == 0 or len(params) == 0:
log_dist(f'The model does NOT support eigenvalue computation.', log_dist(f'The model does NOT support eigenvalue computation.', ranks=[0], level=logging.WARNING)
ranks=[0],
level=logging.WARNING)
return [] return []
i = 0 i = 0
eigenvalue_current, eigenvalue_previous = 1., 0. eigenvalue_current, eigenvalue_previous = 1., 0.
while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs( while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs(
(eigenvalue_current - eigenvalue_previous) / (eigenvalue_current - eigenvalue_previous) / eigenvalue_current) >=
eigenvalue_current) >= self.tol): # test convergence criteria self.tol): # test convergence criteria
eigenvalue_previous = eigenvalue_current eigenvalue_previous = eigenvalue_current
Hv = torch.autograd.grad(grads, Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=True)
params,
grad_outputs=v,
only_inputs=True,
retain_graph=True)
#Hv = [hv.float() for hv in Hv] #Hv = [hv.float() for hv in Hv]
Hv = [self.nan_to_num(hv).float() for hv in Hv] Hv = [self.nan_to_num(hv).float() for hv in Hv]
...@@ -131,9 +128,7 @@ class Eigenvalue(object): ...@@ -131,9 +128,7 @@ class Eigenvalue(object):
block_eigenvalue.append(eigenvalue_current) block_eigenvalue.append(eigenvalue_current)
if self.verbose: if self.verbose:
log_dist( log_dist(f'block: {block}, power iteration: {i}, eigenvalue: {eigenvalue_current}', ranks=[0])
f'block: {block}, power iteration: {i}, eigenvalue: {eigenvalue_current}',
ranks=[0])
block_eigenvalue = self.post_process(block_eigenvalue) block_eigenvalue = self.post_process(block_eigenvalue)
......
""" # Copyright (c) Microsoft Corporation.
Copyright 2019 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
import os import os
import re import re
...@@ -9,6 +10,7 @@ import torch ...@@ -9,6 +10,7 @@ import torch
import hashlib import hashlib
from collections import defaultdict, OrderedDict, deque from collections import defaultdict, OrderedDict, deque
from shutil import copyfile from shutil import copyfile
import gc
from torch.nn.modules import Module from torch.nn.modules import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -19,6 +21,7 @@ from typing import Callable, Dict, Union, Iterable ...@@ -19,6 +21,7 @@ from typing import Callable, Dict, Union, Iterable
import deepspeed import deepspeed
from deepspeed import comm as dist
from deepspeed.runtime.utils import see_memory_usage, DummyOptim from deepspeed.runtime.utils import see_memory_usage, DummyOptim
from .zero.offload_config import OffloadDeviceEnum from .zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
...@@ -31,7 +34,7 @@ from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer ...@@ -31,7 +34,7 @@ from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \ from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER
...@@ -53,7 +56,7 @@ from deepspeed.compression.constants import \ ...@@ -53,7 +56,7 @@ from deepspeed.compression.constants import \
WEIGHT_QUANTIZE_ROUNDING, \ WEIGHT_QUANTIZE_ROUNDING, \
WEIGHT_QUANTIZE_VERBOSE, \ WEIGHT_QUANTIZE_VERBOSE, \
WEIGHT_QUANTIZE_KERNEL WEIGHT_QUANTIZE_KERNEL
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS
from deepspeed.runtime.sparse_tensor import SparseTensor from deepspeed.runtime.sparse_tensor import SparseTensor
from deepspeed.runtime import lr_schedules from deepspeed.runtime import lr_schedules
...@@ -79,7 +82,7 @@ from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayer ...@@ -79,7 +82,7 @@ from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayer
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from .pipe.module import PipelineModule from .pipe.module import PipelineModule
from .utils import ensure_directory_exists, get_ma_status from .utils import get_ma_status
from ..ops.adam import FusedAdam from ..ops.adam import FusedAdam
from ..moe.sharded_moe import TopKGate, MOELayer from ..moe.sharded_moe import TopKGate, MOELayer
from ..moe.layer import MoE from ..moe.layer import MoE
...@@ -92,10 +95,7 @@ from deepspeed.utils.logging import print_json_dist, print_configuration ...@@ -92,10 +95,7 @@ from deepspeed.utils.logging import print_json_dist, print_configuration
from deepspeed.accelerator import get_accelerator from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import UtilsBuilder from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.inference.config import DtypeEnum from deepspeed.runtime.config import DtypeEnum
# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None
MEMORY_OPT_ALLREDUCE_SIZE = 500000000 MEMORY_OPT_ALLREDUCE_SIZE = 500000000
...@@ -110,16 +110,13 @@ try: ...@@ -110,16 +110,13 @@ try:
except ImportError: except ImportError:
# Fail silently so we don't spam logs unnecessarily if user isn't using amp # Fail silently so we don't spam logs unnecessarily if user isn't using amp
APEX_INSTALLED = False APEX_INSTALLED = False
pass
def split_half_float_double_sparse(tensors): def split_half_float_double_sparse(tensors):
device_type = get_accelerator().device_name() device_type = get_accelerator().device_name()
supported_types = [ supported_types = [
"torch.{}.HalfTensor".format(device_type), "torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type),
"torch.{}.FloatTensor".format(device_type), "torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type),
"torch.{}.DoubleTensor".format(device_type),
"torch.{}.BFloat16Tensor".format(device_type),
SparseTensor.type() SparseTensor.type()
] ]
...@@ -148,6 +145,7 @@ STEP_GLOBAL_TIMER = 'step' ...@@ -148,6 +145,7 @@ STEP_GLOBAL_TIMER = 'step'
class EngineTimers(object): class EngineTimers(object):
r"""Wallclock timers for DeepSpeedEngine""" r"""Wallclock timers for DeepSpeedEngine"""
def __init__(self, enable_micro_timers, enable_global_timers): def __init__(self, enable_micro_timers, enable_global_timers):
self.forward_timers = [] self.forward_timers = []
self.backward_timers = [] self.backward_timers = []
...@@ -164,10 +162,7 @@ class EngineTimers(object): ...@@ -164,10 +162,7 @@ class EngineTimers(object):
self.backward_reduce_timers += [BACKWARD_REDUCE_MICRO_TIMER] self.backward_reduce_timers += [BACKWARD_REDUCE_MICRO_TIMER]
self.step_timers += [STEP_MICRO_TIMER] self.step_timers += [STEP_MICRO_TIMER]
self.micro_timers += [ self.micro_timers += [
FORWARD_MICRO_TIMER, FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER,
BACKWARD_MICRO_TIMER,
BACKWARD_INNER_MICRO_TIMER,
BACKWARD_REDUCE_MICRO_TIMER,
STEP_MICRO_TIMER STEP_MICRO_TIMER
] ]
...@@ -178,16 +173,14 @@ class EngineTimers(object): ...@@ -178,16 +173,14 @@ class EngineTimers(object):
self.backward_reduce_timers += [BACKWARD_REDUCE_GLOBAL_TIMER] self.backward_reduce_timers += [BACKWARD_REDUCE_GLOBAL_TIMER]
self.step_timers += [STEP_GLOBAL_TIMER] self.step_timers += [STEP_GLOBAL_TIMER]
self.global_timers += [ self.global_timers += [
FORWARD_GLOBAL_TIMER, FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER,
BACKWARD_GLOBAL_TIMER,
BACKWARD_INNER_GLOBAL_TIMER,
BACKWARD_REDUCE_GLOBAL_TIMER,
STEP_GLOBAL_TIMER STEP_GLOBAL_TIMER
] ]
class DeepSpeedEngine(Module): class DeepSpeedEngine(Module):
r"""DeepSpeed engine for training.""" r"""DeepSpeed engine for training."""
def __init__( def __init__(
self, self,
args, args,
...@@ -200,7 +193,7 @@ class DeepSpeedEngine(Module): ...@@ -200,7 +193,7 @@ class DeepSpeedEngine(Module):
dist_init_required=None, dist_init_required=None,
collate_fn=None, collate_fn=None,
config=None, config=None,
config_params=None, config_class=None,
dont_change_device=False, dont_change_device=False,
): ):
super(DeepSpeedEngine, self).__init__() super(DeepSpeedEngine, self).__init__()
...@@ -218,6 +211,7 @@ class DeepSpeedEngine(Module): ...@@ -218,6 +211,7 @@ class DeepSpeedEngine(Module):
self.gradient_average = True self.gradient_average = True
self.warn_unscaled_loss = True self.warn_unscaled_loss = True
self.config = config self.config = config
self._config = config_class
self.loaded_checkpoint_mp_world_size = None self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None self.loaded_checkpoint_dp_world_size = None
self.enable_backward_allreduce = True self.enable_backward_allreduce = True
...@@ -236,8 +230,6 @@ class DeepSpeedEngine(Module): ...@@ -236,8 +230,6 @@ class DeepSpeedEngine(Module):
self.checkpoint_engine = None self.checkpoint_engine = None
global dist
from deepspeed import comm as dist
self._is_gradient_accumulation_boundary = None self._is_gradient_accumulation_boundary = None
self.scale_wrt_gas = None self.scale_wrt_gas = None
...@@ -247,38 +239,15 @@ class DeepSpeedEngine(Module): ...@@ -247,38 +239,15 @@ class DeepSpeedEngine(Module):
# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
self.param_names = {param: name for name, param in model.named_parameters()} self.param_names = {param: name for name, param in model.named_parameters()}
# Set config using config_params for backwards compat
if self.config is None and config_params is not None:
self.config = config_params
from deepspeed.comm import supported_torch_version
# This supported_torch_version check is for torch1.2 compatibility only
if supported_torch_version:
dist.init_distributed(dist_backend=self.dist_backend,
dist_init_required=dist_init_required)
else:
if dist_init_required is None:
dist_init_required = not dist.is_initialized()
if dist_init_required is False:
assert (
dist.is_initialized() is True
), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
else:
if not dist.is_initialized():
dist.init_process_group(backend=self.dist_backend)
self._do_args_sanity_check(args) self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu) self._configure_with_arguments(args, mpu)
self._do_sanity_check() self._do_sanity_check()
see_memory_usage(f"DeepSpeed Engine: After args sanity test", see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
force=self.memory_breakdown())
if mpu is not None: if mpu is not None:
if self.elasticity_enabled(): if self.elasticity_enabled():
if not self.is_elastic_model_parallel_supported(): if not self.is_elastic_model_parallel_supported():
assert not self.elasticity_enabled(), ( assert not self.elasticity_enabled(), ("Elasticity is not currently supported"
"Elasticity is not currently supported" " with model parallelism." " with model parallelism.")
)
self._set_distributed_vars(args) self._set_distributed_vars(args)
...@@ -309,8 +278,7 @@ class DeepSpeedEngine(Module): ...@@ -309,8 +278,7 @@ class DeepSpeedEngine(Module):
monitor_memory=False, monitor_memory=False,
) )
log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", ranks=[0])
ranks=[0])
if self.flops_profiler_enabled(): if self.flops_profiler_enabled():
self.flops_profiler = FlopsProfiler(self.module, self) self.flops_profiler = FlopsProfiler(self.module, self)
...@@ -332,6 +300,10 @@ class DeepSpeedEngine(Module): ...@@ -332,6 +300,10 @@ class DeepSpeedEngine(Module):
if model_parameters is None: if model_parameters is None:
model_parameters = self.module.parameters() model_parameters = self.module.parameters()
# Convert model parameters from generator to list
if not isinstance(model_parameters, list):
model_parameters = list(model_parameters)
if has_optimizer: if has_optimizer:
self._configure_optimizer(optimizer, model_parameters) self._configure_optimizer(optimizer, model_parameters)
self._configure_lr_scheduler(lr_scheduler) self._configure_lr_scheduler(lr_scheduler)
...@@ -346,12 +318,9 @@ class DeepSpeedEngine(Module): ...@@ -346,12 +318,9 @@ class DeepSpeedEngine(Module):
self.sparse_tensor_module_names = set() self.sparse_tensor_module_names = set()
# if self.sparse_gradients_enabled(): # if self.sparse_gradients_enabled():
for name, module in self.module.named_modules(): for name, module in self.module.named_modules():
if isinstance(module, if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)) and self.sparse_gradients_enabled():
(torch.nn.Embedding,
torch.nn.EmbeddingBag)) and self.sparse_gradients_enabled():
self.sparse_tensor_module_names.add(name + ".weight") self.sparse_tensor_module_names.add(name + ".weight")
logger.info( logger.info("Will convert {} to sparse tensor during training".format(name))
"Will convert {} to sparse tensor during training".format(name))
self.save_non_zero_checkpoint = False self.save_non_zero_checkpoint = False
self.save_zero_checkpoint = False self.save_zero_checkpoint = False
...@@ -365,23 +334,19 @@ class DeepSpeedEngine(Module): ...@@ -365,23 +334,19 @@ class DeepSpeedEngine(Module):
self.progressive_layer_drop = self._configure_progressive_layer_drop() self.progressive_layer_drop = self._configure_progressive_layer_drop()
if self.curriculum_enabled_legacy(): if self.curriculum_enabled_legacy():
self.curriculum_scheduler_legacy = self._configure_curriculum_scheduler_legacy( self.curriculum_scheduler_legacy = self._configure_curriculum_scheduler_legacy()
)
if self.random_ltd_enabled(): if self.random_ltd_enabled():
random_ltd_config = self.random_ltd_config() random_ltd_config = self.random_ltd_config()
random_ltd_config[RANDOM_LTD_GLOBAL_BATCH_SIZE] = self.train_batch_size() random_ltd_config[RANDOM_LTD_GLOBAL_BATCH_SIZE] = self.train_batch_size()
random_ltd_config[ random_ltd_config[RANDOM_LTD_MICRO_BATCH_SIZE] = self.train_micro_batch_size_per_gpu()
RANDOM_LTD_MICRO_BATCH_SIZE] = self.train_micro_batch_size_per_gpu() self.random_ltd_scheduler = self._configure_random_ltd_scheduler(random_ltd_config)
self.random_ltd_scheduler = self._configure_random_ltd_scheduler(
random_ltd_config)
# Engine timers # Engine timers
self.engine_timers = EngineTimers( self.engine_timers = EngineTimers(enable_micro_timers=self.wall_clock_breakdown(),
enable_micro_timers=self.wall_clock_breakdown(), enable_global_timers=self.wall_clock_breakdown()
enable_global_timers=self.wall_clock_breakdown() or self.flops_profiler_enabled())
or self.flops_profiler_enabled())
if self.global_rank == 0: if self.global_rank == 0:
self._config.print("DeepSpeedEngine configuration") self._config.print("DeepSpeedEngine configuration")
...@@ -414,10 +379,8 @@ class DeepSpeedEngine(Module): ...@@ -414,10 +379,8 @@ class DeepSpeedEngine(Module):
if p.requires_grad: if p.requires_grad:
trainable_num_params += n trainable_num_params += n
if self.global_rank == 0: if self.global_rank == 0:
self.autotuning_model_info[ self.autotuning_model_info["num_params"] = num_params * self.mp_world_size
"num_params"] = num_params * self.mp_world_size self.autotuning_model_info["trainable_num_params"] = trainable_num_params * self.mp_world_size
self.autotuning_model_info[
"trainable_num_params"] = trainable_num_params * self.mp_world_size
logger.info(f"model parameter = {num_params}") logger.info(f"model parameter = {num_params}")
...@@ -447,13 +410,10 @@ class DeepSpeedEngine(Module): ...@@ -447,13 +410,10 @@ class DeepSpeedEngine(Module):
ValueError: if ``train_batch_size`` is not divisible by the ValueError: if ``train_batch_size`` is not divisible by the
configured micro-batch size and data parallelism. configured micro-batch size and data parallelism.
""" """
if train_batch_size % (self.train_micro_batch_size_per_gpu() * if train_batch_size % (self.train_micro_batch_size_per_gpu() * self.dp_world_size) != 0:
self.dp_world_size) != 0:
#print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}') #print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}')
raise ValueError( raise ValueError(f'Train batch size must be divisible by micro-batch data parallelism')
f'Train batch size must be divisible by micro-batch data parallelism') new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() * self.dp_world_size)
new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() *
self.dp_world_size)
# overwrite config # overwrite config
self._config.train_batch_size = train_batch_size self._config.train_batch_size = train_batch_size
self._config.gradient_accumulation_steps = new_gas self._config.gradient_accumulation_steps = new_gas
...@@ -464,8 +424,7 @@ class DeepSpeedEngine(Module): ...@@ -464,8 +424,7 @@ class DeepSpeedEngine(Module):
def set_custom_curriculum_learning_schedule(self, schedule_func_dict): def set_custom_curriculum_learning_schedule(self, schedule_func_dict):
if self.training_dataloader is not None and self.curriculum_learning_enabled(): if self.training_dataloader is not None and self.curriculum_learning_enabled():
self.training_dataloader.data_sampler.set_custom_curriculum_learning_schedule( self.training_dataloader.data_sampler.set_custom_curriculum_learning_schedule(schedule_func_dict)
schedule_func_dict)
def get_global_grad_norm(self) -> float: def get_global_grad_norm(self) -> float:
"""Return the 2-norm of all gradients. If there is model parallelism, """Return the 2-norm of all gradients. If there is model parallelism,
...@@ -492,8 +451,7 @@ class DeepSpeedEngine(Module): ...@@ -492,8 +451,7 @@ class DeepSpeedEngine(Module):
elif name in dir(_module): elif name in dir(_module):
return getattr(_module, name) return getattr(_module, name)
else: else:
raise AttributeError( raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
f"'{type(self).__name__}' object has no attribute '{name}'")
def checkpoint_tag_validation_enabled(self): def checkpoint_tag_validation_enabled(self):
return self._config.checkpoint_tag_validation_enabled return self._config.checkpoint_tag_validation_enabled
...@@ -567,15 +525,13 @@ class DeepSpeedEngine(Module): ...@@ -567,15 +525,13 @@ class DeepSpeedEngine(Module):
return self._config.data_efficiency_config[DATA_SAMPLING] return self._config.data_efficiency_config[DATA_SAMPLING]
def curriculum_learning_enabled(self): def curriculum_learning_enabled(self):
return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][ return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED]
CURRICULUM_LEARNING_ENABLED]
def curriculum_learning_config(self): def curriculum_learning_config(self):
return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING] return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING]
def random_ltd_enabled(self): def random_ltd_enabled(self):
return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD][ return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD][RANDOM_LTD_ENABLED]
RANDOM_LTD_ENABLED]
def random_ltd_config(self): def random_ltd_config(self):
return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD] return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD]
...@@ -583,26 +539,20 @@ class DeepSpeedEngine(Module): ...@@ -583,26 +539,20 @@ class DeepSpeedEngine(Module):
def random_ltd_initialize(self): def random_ltd_initialize(self):
assert self.random_ltd_enabled() assert self.random_ltd_enabled()
random_ltd_config = self.random_ltd_config() random_ltd_config = self.random_ltd_config()
random_ltd_queue = deque( random_ltd_queue = deque([x for x in sorted(random_ltd_config[RANDOM_LTD_LAYER_ID])])
[x for x in sorted(random_ltd_config[RANDOM_LTD_LAYER_ID])])
count = 0 count = 0
for name, layer in self.module.named_modules(): for name, layer in self.module.named_modules():
if isinstance(layer, RandomLayerTokenDrop): if isinstance(layer, RandomLayerTokenDrop):
if len(random_ltd_queue) != 0 and str( if len(random_ltd_queue) != 0 and str(random_ltd_queue[0]) in name: ###[1,2,3]
random_ltd_queue[0]) in name: ###[1,2,3] layer.init_config(random_ltd_config, self.random_ltd_scheduler, count)
layer.init_config(random_ltd_config,
self.random_ltd_scheduler,
count)
random_ltd_queue.popleft() random_ltd_queue.popleft()
count += 1 count += 1
if random_ltd_config[RANDOM_LTD_LAYER_NUM] != count: if random_ltd_config[RANDOM_LTD_LAYER_NUM] != count:
raise ValueError( raise ValueError(f'random_ltd_layer_num {random_ltd_config[RANDOM_LTD_LAYER_NUM]} must be \
f'random_ltd_layer_num {random_ltd_config[RANDOM_LTD_LAYER_NUM]} must be \
equivalent to the len of random_ltd_layer_id {count}') equivalent to the len of random_ltd_layer_id {count}')
if random_ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][ if random_ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]:
RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]:
assert self.client_lr_scheduler is None assert self.client_lr_scheduler is None
raise ValueError(f'not yet support') raise ValueError(f'not yet support')
#self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler) #self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler)
...@@ -663,8 +613,7 @@ class DeepSpeedEngine(Module): ...@@ -663,8 +613,7 @@ class DeepSpeedEngine(Module):
def autotuning_profile_model_info(self): def autotuning_profile_model_info(self):
return self.autotuning_enabled( return self.autotuning_enabled(
) and self._config.autotuning_config.model_info and self._config.autotuning_config.model_info.get( ) and self._config.autotuning_config.model_info and self._config.autotuning_config.model_info.get(
"profile", "profile", False)
False)
def sparse_gradients_enabled(self): def sparse_gradients_enabled(self):
return self._config.sparse_gradients_enabled return self._config.sparse_gradients_enabled
...@@ -676,8 +625,7 @@ class DeepSpeedEngine(Module): ...@@ -676,8 +625,7 @@ class DeepSpeedEngine(Module):
return self._config.train_micro_batch_size_per_gpu return self._config.train_micro_batch_size_per_gpu
def optimizer_name(self): def optimizer_name(self):
return (self.client_optimizer.__class__.__name__ return (self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name)
if self.client_optimizer else self._config.optimizer_name)
def optimizer_params(self): def optimizer_params(self):
return self._config.optimizer_params return self._config.optimizer_params
...@@ -695,22 +643,15 @@ class DeepSpeedEngine(Module): ...@@ -695,22 +643,15 @@ class DeepSpeedEngine(Module):
return ( return (
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED], [WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED],
[WEIGHT_QUANTIZE_ENABLED], self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_GROUPS],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
[WEIGHT_QUANTIZE_GROUPS],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], [WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_CHANGE_RATIO],
[WEIGHT_QUANTIZE_CHANGE_RATIO], self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_TYPE],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ROUNDING],
[WEIGHT_QUANTIZE_TYPE], self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_VERBOSE],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_KERNEL],
[WEIGHT_QUANTIZE_ROUNDING],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
[WEIGHT_QUANTIZE_VERBOSE],
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
[WEIGHT_QUANTIZE_KERNEL],
) )
def zero_optimization(self): def zero_optimization(self):
...@@ -719,6 +660,9 @@ class DeepSpeedEngine(Module): ...@@ -719,6 +660,9 @@ class DeepSpeedEngine(Module):
def zero_allow_untested_optimizer(self): def zero_allow_untested_optimizer(self):
return self._config.zero_allow_untested_optimizer return self._config.zero_allow_untested_optimizer
def zero_force_ds_cpu_optimizer(self):
return self._config.zero_force_ds_cpu_optimizer
def zero_reduce_scatter(self): def zero_reduce_scatter(self):
return self._config.zero_config.reduce_scatter return self._config.zero_config.reduce_scatter
...@@ -733,10 +677,7 @@ class DeepSpeedEngine(Module): ...@@ -733,10 +677,7 @@ class DeepSpeedEngine(Module):
def zero_use_cpu_optimizer(self): def zero_use_cpu_optimizer(self):
if self._config.zero_config.offload_optimizer is not None: if self._config.zero_config.offload_optimizer is not None:
return self._config.zero_config.offload_optimizer.device in [ return self._config.zero_config.offload_optimizer.device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]
OffloadDeviceEnum.cpu,
OffloadDeviceEnum.nvme
]
return False return False
def zero_cpu_offload(self): def zero_cpu_offload(self):
...@@ -750,6 +691,9 @@ class DeepSpeedEngine(Module): ...@@ -750,6 +691,9 @@ class DeepSpeedEngine(Module):
def zero_optimization_stage(self): def zero_optimization_stage(self):
return self._config.zero_optimization_stage return self._config.zero_optimization_stage
def mics_shard_size(self):
return self._config.mics_shard_size
def zero_reduce_bucket_size(self): def zero_reduce_bucket_size(self):
return self._config.zero_config.reduce_bucket_size return self._config.zero_config.reduce_bucket_size
...@@ -833,9 +777,11 @@ class DeepSpeedEngine(Module): ...@@ -833,9 +777,11 @@ class DeepSpeedEngine(Module):
res = self._config.communication_data_type res = self._config.communication_data_type
if res is not None: if res is not None:
return res return res
elif self.fp16_enabled() or self.zero_optimization_stage():
if self.fp16_enabled():
return torch.float16 return torch.float16
elif self.bfloat16_enabled():
if self.bfloat16_enabled():
return torch.bfloat16 return torch.bfloat16
return torch.float32 return torch.float32
...@@ -897,14 +843,11 @@ class DeepSpeedEngine(Module): ...@@ -897,14 +843,11 @@ class DeepSpeedEngine(Module):
# First check for scheduler in json configuration # First check for scheduler in json configuration
lr_scheduler = self._scheduler_from_config(self.optimizer) lr_scheduler = self._scheduler_from_config(self.optimizer)
if lr_scheduler: if lr_scheduler:
log_dist( log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0])
f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}",
ranks=[0])
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
else: else:
if isinstance(client_lr_scheduler, Callable): if isinstance(client_lr_scheduler, Callable):
log_dist('DeepSpeed using client callable to create LR scheduler', log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0])
ranks=[0])
self.lr_scheduler = client_lr_scheduler(self.basic_optimizer) self.lr_scheduler = client_lr_scheduler(self.basic_optimizer)
else: else:
log_dist('DeepSpeed using client LR scheduler', ranks=[0]) log_dist('DeepSpeed using client LR scheduler', ranks=[0])
...@@ -919,12 +862,9 @@ class DeepSpeedEngine(Module): ...@@ -919,12 +862,9 @@ class DeepSpeedEngine(Module):
try: try:
from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \ from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \
NebulaCheckpointEngine NebulaCheckpointEngine
self.checkpoint_engine = NebulaCheckpointEngine( self.checkpoint_engine = NebulaCheckpointEngine(config_params=self._config.nebula_config)
config_params=self._config.nebula_config)
except ImportError as err: except ImportError as err:
logger.error( logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
f"No torch_nebula was found! Will fall back to torch.save. Details: {err}"
)
self.checkpoint_engine = TorchCheckpointEngine() self.checkpoint_engine = TorchCheckpointEngine()
dp_rank = self.global_rank dp_rank = self.global_rank
...@@ -936,8 +876,7 @@ class DeepSpeedEngine(Module): ...@@ -936,8 +876,7 @@ class DeepSpeedEngine(Module):
# only the first data parallel process needs to store the model checkpoint # only the first data parallel process needs to store the model checkpoint
# if you want to use node local storage this must be done by rank 0 on each # if you want to use node local storage this must be done by rank 0 on each
# node # node
self.save_non_zero_checkpoint = ( self.save_non_zero_checkpoint = (rank == 0) or self.zero_optimization_partition_weights()
rank == 0) or self.zero_optimization_partition_weights()
if self.zero_optimization() or self.bfloat16_enabled(): if self.zero_optimization() or self.bfloat16_enabled():
param_rank = dist.get_rank(group=self.optimizer.dp_process_group) param_rank = dist.get_rank(group=self.optimizer.dp_process_group)
...@@ -952,9 +891,8 @@ class DeepSpeedEngine(Module): ...@@ -952,9 +891,8 @@ class DeepSpeedEngine(Module):
if hasattr(lr_schedules, scheduler_name): if hasattr(lr_schedules, scheduler_name):
scheduler = getattr(lr_schedules, scheduler_name) scheduler = getattr(lr_schedules, scheduler_name)
else: else:
assert hasattr( assert hasattr(torch.optim.lr_scheduler,
torch.optim.lr_scheduler, scheduler_name scheduler_name), f"DeepSpeed does not recognize LR scheduler {scheduler_name}"
), f"DeepSpeed does not recognize LR scheduler {scheduler_name}"
scheduler = getattr(torch.optim.lr_scheduler, scheduler_name) scheduler = getattr(torch.optim.lr_scheduler, scheduler_name)
...@@ -965,9 +903,7 @@ class DeepSpeedEngine(Module): ...@@ -965,9 +903,7 @@ class DeepSpeedEngine(Module):
return None return None
def _set_distributed_vars(self, args): def _set_distributed_vars(self, args):
device_rank = args.device_rank if args is not None and hasattr( device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank
args,
'device_rank') else self.local_rank
if device_rank >= 0: if device_rank >= 0:
get_accelerator().set_device(device_rank) get_accelerator().set_device(device_rank)
self.device = torch.device(get_accelerator().device_name(), device_rank) self.device = torch.device(get_accelerator().device_name(), device_rank)
...@@ -996,48 +932,23 @@ class DeepSpeedEngine(Module): ...@@ -996,48 +932,23 @@ class DeepSpeedEngine(Module):
if hasattr(args, 'local_rank'): if hasattr(args, 'local_rank'):
args.local_rank = self.local_rank args.local_rank = self.local_rank
if self.config is None:
self.config = (args.deepspeed_config
if hasattr(args,
"deepspeed_config") else None)
self._config = DeepSpeedConfig(self.config, mpu)
# Validate command line arguments # Validate command line arguments
def _do_args_sanity_check(self, args): def _do_args_sanity_check(self, args):
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
logger.warning(
"************ --deepscale_config is deprecated, please use --deepspeed_config ************"
)
if hasattr(args, "deepspeed_config"):
assert (
args.deepspeed_config is None
), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
args.deepspeed_config = args.deepscale_config
assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \ assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \
"variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \ "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \
"different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed." "different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed."
if hasattr(args, 'local_rank') and args.local_rank != None: if hasattr(args, 'local_rank') and args.local_rank != None:
assert isinstance( assert isinstance(args.local_rank,
args.local_rank, int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}" int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}"
if args.local_rank >= 0: if args.local_rank >= 0:
env_local_rank = int(os.environ.get("LOCAL_RANK")) env_local_rank = int(os.environ.get("LOCAL_RANK"))
assert ( assert (
env_local_rank == args.local_rank env_local_rank == args.local_rank
), f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}." ), f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}."
if self.config is None:
assert (
hasattr(
args, "deepspeed_config") and args.deepspeed_config is not None
), "DeepSpeed requires --deepspeed_config to specify configuration file"
def _is_supported_optimizer(self, optimizer_name): def _is_supported_optimizer(self, optimizer_name):
return (optimizer_name in DEEPSPEED_OPTIMIZERS return (optimizer_name in DEEPSPEED_OPTIMIZERS or getattr(torch.optim, optimizer_name, None) is not None)
or getattr(torch.optim,
optimizer_name,
None) is not None)
def _supported_optims(self): def _supported_optims(self):
FairseqOptimizer = None FairseqOptimizer = None
...@@ -1062,18 +973,11 @@ class DeepSpeedEngine(Module): ...@@ -1062,18 +973,11 @@ class DeepSpeedEngine(Module):
if not self.client_optimizer: if not self.client_optimizer:
if self.optimizer_name() is not None: if self.optimizer_name() is not None:
assert self._is_supported_optimizer( assert self._is_supported_optimizer(
self.optimizer_name() self.optimizer_name()), "{} is not a supported DeepSpeed Optimizer".format(self.optimizer_name())
), "{} is not a supported DeepSpeed Optimizer".format(
self.optimizer_name()
)
if (self.optimizer_name() == LAMB_OPTIMIZER if (self.optimizer_name() == LAMB_OPTIMIZER or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER):
or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER): assert (self.dynamic_loss_scale()), "DeepSpeed {} optimizer requires dynamic loss scaling".format(
assert ( self.optimizer_name())
self.dynamic_loss_scale()
), "DeepSpeed {} optimizer requires dynamic loss scaling".format(
self.optimizer_name()
)
# Detect invalid combinations of client optimizer and client scheduler # Detect invalid combinations of client optimizer and client scheduler
if isinstance(self.client_lr_scheduler, _LRScheduler): if isinstance(self.client_lr_scheduler, _LRScheduler):
...@@ -1081,6 +985,7 @@ class DeepSpeedEngine(Module): ...@@ -1081,6 +985,7 @@ class DeepSpeedEngine(Module):
f'Client Optimizer (type = {type(self.client_optimizer)} is not instantiated but Client LR Scheduler is instantiated' f'Client Optimizer (type = {type(self.client_optimizer)} is not instantiated but Client LR Scheduler is instantiated'
def _broadcast_model(self): def _broadcast_model(self):
def is_replicated(p): def is_replicated(p):
if hasattr(p, "ds_status") and p.ds_status is not ZeroParamStatus.AVAILABLE: if hasattr(p, "ds_status") and p.ds_status is not ZeroParamStatus.AVAILABLE:
return False return False
...@@ -1095,20 +1000,15 @@ class DeepSpeedEngine(Module): ...@@ -1095,20 +1000,15 @@ class DeepSpeedEngine(Module):
group=self.expert_data_parallel_group[p.group_name]) group=self.expert_data_parallel_group[p.group_name])
else: else:
if torch.is_tensor(p) and is_replicated(p): if torch.is_tensor(p) and is_replicated(p):
dist.broadcast(p, dist.broadcast(p, groups._get_broadcast_src_rank(), group=self.data_parallel_group)
groups._get_broadcast_src_rank(),
group=self.data_parallel_group)
@staticmethod @staticmethod
def __check_params(model: Module, dtype: torch.dtype) -> None: def __check_params(model: Module, dtype: torch.dtype) -> None:
return return
if not all(param.dtype == dtype if not all(param.dtype == dtype for param in model.parameters()) and dist.get_rank() == 0:
for param in model.parameters()) and dist.get_rank() == 0: raise ValueError(f"{dtype} is enabled but the following parameters have dtype that is "
raise ValueError( f"not {dtype}: "
f"{dtype} is enabled but the following parameters have dtype that is " f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}")
f"not {dtype}: "
f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}"
)
def _set_client_model(self, model): def _set_client_model(self, model):
# register client model in _modules so that nn.module methods work correctly # register client model in _modules so that nn.module methods work correctly
...@@ -1122,14 +1022,12 @@ class DeepSpeedEngine(Module): ...@@ -1122,14 +1022,12 @@ class DeepSpeedEngine(Module):
if self.fp16_enabled(): if self.fp16_enabled():
if self.zero_optimization_partition_weights() and any( if self.zero_optimization_partition_weights() and any(
[hasattr(param, [hasattr(param, "ds_id") for param in self.module.parameters()]):
"ds_id") for param in self.module.parameters()]):
self.__check_params(self.module, torch.half) self.__check_params(self.module, torch.half)
self.module.half() self.module.half()
elif self.bfloat16_enabled(): elif self.bfloat16_enabled():
if self.zero_optimization_partition_weights() and any( if self.zero_optimization_partition_weights() and any(
hasattr(param, hasattr(param, 'ds_id') for param in self.module.parameters()):
'ds_id') for param in self.module.parameters()):
self.__check_params(self.module, torch.bfloat16) self.__check_params(self.module, torch.bfloat16)
self.module.bfloat16() self.module.bfloat16()
else: else:
...@@ -1183,8 +1081,7 @@ class DeepSpeedEngine(Module): ...@@ -1183,8 +1081,7 @@ class DeepSpeedEngine(Module):
return [id(param) for param in group] return [id(param) for param in group]
occurrence = sum([ occurrence = sum([
ids_list(group['params']).count(param_id) ids_list(group['params']).count(param_id) if param_id in ids_list(group['params']) else 0
if param_id in ids_list(group['params']) else 0
for group in optimizer.param_groups for group in optimizer.param_groups
]) ])
assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behaviour." assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behaviour."
...@@ -1204,9 +1101,7 @@ class DeepSpeedEngine(Module): ...@@ -1204,9 +1101,7 @@ class DeepSpeedEngine(Module):
), 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.' ), 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
if self.global_rank == 0: if self.global_rank == 0:
logger.warning( logger.warning("**** You are using ZeRO with an untested optimizer, proceed with caution *****")
"**** You are using ZeRO with an untested optimizer, proceed with caution *****"
)
if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and self.zero_optimization_stage( if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and self.zero_optimization_stage(
) == 1: ) == 1:
...@@ -1214,23 +1109,19 @@ class DeepSpeedEngine(Module): ...@@ -1214,23 +1109,19 @@ class DeepSpeedEngine(Module):
if model_dtype != grad_accum_dtype: if model_dtype != grad_accum_dtype:
raise NotImplementedError( raise NotImplementedError(
"Model data type and gradient accumulation data type must be equal to use ZeRO" "Model data type and gradient accumulation data type must be equal to use ZeRO")
)
return ZERO_OPTIMIZATION return ZERO_OPTIMIZATION
elif amp_enabled: elif amp_enabled:
if model_dtype != grad_accum_dtype: if model_dtype != grad_accum_dtype:
raise NotImplementedError( raise NotImplementedError(
"Model data type and gradient accumulation data type must be equal to use Amp" "Model data type and gradient accumulation data type must be equal to use Amp")
)
if model_dtype == torch.bfloat16 or model_dtype == torch.float16: if model_dtype == torch.bfloat16 or model_dtype == torch.float16:
raise NotImplementedError( raise NotImplementedError("Cannot enable both amp with (legacy) fp16 or bfloat16 mode")
"Cannot enable both amp with (legacy) fp16 or bfloat16 mode")
try: try:
logger.info("Initializing Apex amp from: {}".format(amp.__path__)) logger.info("Initializing Apex amp from: {}".format(amp.__path__))
except NameError: except NameError:
# If apex/amp is available it will be imported above # If apex/amp is available it will be imported above
raise RuntimeError( raise RuntimeError("Unable to import apex/amp, please make sure it is installed")
"Unable to import apex/amp, please make sure it is installed")
return AMP return AMP
# data type checks # data type checks
elif model_dtype == grad_accum_dtype: elif model_dtype == grad_accum_dtype:
...@@ -1244,8 +1135,7 @@ class DeepSpeedEngine(Module): ...@@ -1244,8 +1135,7 @@ class DeepSpeedEngine(Module):
elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32: elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32:
return BFLOAT16 return BFLOAT16
else: else:
raise NotImplementedError( raise NotImplementedError("unsupported mix of model dtype and gradient accummulation type")
"unsupported mix of model dtype and gradient accummulation type")
return None return None
...@@ -1256,27 +1146,26 @@ class DeepSpeedEngine(Module): ...@@ -1256,27 +1146,26 @@ class DeepSpeedEngine(Module):
client_optimizer.param_groups[:] = [ client_optimizer.param_groups[:] = [
pg for pg in client_optimizer.param_groups if len(pg["params"]) != 0 pg for pg in client_optimizer.param_groups if len(pg["params"]) != 0
] ]
log_dist( log_dist("Removing param_group that has no 'params' in the client Optimizer", ranks=[0])
"Removing param_group that has no 'params' in the client Optimizer",
ranks=[0])
basic_optimizer = client_optimizer basic_optimizer = client_optimizer
log_dist('Using client Optimizer as basic optimizer', ranks=[0]) log_dist('Using client Optimizer as basic optimizer', ranks=[0])
else: else:
basic_optimizer = client_optimizer(model_parameters) basic_optimizer = client_optimizer(model_parameters)
log_dist('Using client callable to create basic optimizer', ranks=[0]) log_dist('Using client callable to create basic optimizer', ranks=[0])
if self.zero_use_cpu_optimizer() and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam):
if self.zero_force_ds_cpu_optimizer():
msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. ZeRO-Offload and understand the performance impacts you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.'
raise ZeRORuntimeException(msg)
else: else:
basic_optimizer = self._configure_basic_optimizer(model_parameters) basic_optimizer = self._configure_basic_optimizer(model_parameters)
log_dist( log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0])
f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer",
ranks=[0])
self._check_for_duplicates(basic_optimizer) self._check_for_duplicates(basic_optimizer)
self.basic_optimizer = basic_optimizer self.basic_optimizer = basic_optimizer
log_dist("DeepSpeed Basic Optimizer = {}".format( log_dist("DeepSpeed Basic Optimizer = {}".format(basic_optimizer.__class__.__name__), ranks=[0])
basic_optimizer.__class__.__name__),
ranks=[0])
optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer) optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer)
...@@ -1285,9 +1174,7 @@ class DeepSpeedEngine(Module): ...@@ -1285,9 +1174,7 @@ class DeepSpeedEngine(Module):
elif optimizer_wrapper == AMP: elif optimizer_wrapper == AMP:
amp_params = self.amp_params() amp_params = self.amp_params()
log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0]) log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0])
model, self.optimizer = amp.initialize( model, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
self.module, basic_optimizer, **amp_params
)
self._set_client_model(model) self._set_client_model(model)
self._broadcast_model() self._broadcast_model()
# TODO: maybe need to broadcast experts differently? # TODO: maybe need to broadcast experts differently?
...@@ -1298,8 +1185,7 @@ class DeepSpeedEngine(Module): ...@@ -1298,8 +1185,7 @@ class DeepSpeedEngine(Module):
else: else:
self.optimizer = basic_optimizer self.optimizer = basic_optimizer
log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), ranks=[0])
ranks=[0])
self.compression_scheduler = self._configure_compression_scheduler() self.compression_scheduler = self._configure_compression_scheduler()
self.quantizer = self._configure_quantization() self.quantizer = self._configure_quantization()
...@@ -1314,32 +1200,24 @@ class DeepSpeedEngine(Module): ...@@ -1314,32 +1200,24 @@ class DeepSpeedEngine(Module):
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details" "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
) )
if self.optimizer_name() in [ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER]: if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False) torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT) adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT)
# Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set # Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set
effective_adam_w_mode = self.optimizer_name( effective_adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER or adam_w_mode
) == ADAMW_OPTIMIZER or adam_w_mode
if torch_adam: if torch_adam:
if not effective_adam_w_mode: if not effective_adam_w_mode:
optimizer = torch.optim.Adam(model_parameters, optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
**optimizer_parameters)
else: else:
optimizer = torch.optim.AdamW(model_parameters, optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters)
**optimizer_parameters)
else: else:
if self.zero_use_cpu_optimizer(): if self.zero_use_cpu_optimizer():
if self.optimizer_name() == ADAGRAD_OPTIMIZER: from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad optimizer = DeepSpeedCPUAdam(model_parameters,
optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters,
**optimizer_parameters) adamw_mode=effective_adam_w_mode)
else:
from deepspeed.ops.adam import DeepSpeedCPUAdam
optimizer = DeepSpeedCPUAdam(model_parameters,
**optimizer_parameters,
adamw_mode=effective_adam_w_mode)
else: else:
from deepspeed.ops.adam import FusedAdam from deepspeed.ops.adam import FusedAdam
...@@ -1349,6 +1227,12 @@ class DeepSpeedEngine(Module): ...@@ -1349,6 +1227,12 @@ class DeepSpeedEngine(Module):
adam_w_mode=effective_adam_w_mode, adam_w_mode=effective_adam_w_mode,
) )
elif self.optimizer_name() == ADAGRAD_OPTIMIZER:
if self.zero_use_cpu_optimizer():
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters)
else:
optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == LAMB_OPTIMIZER: elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb from deepspeed.ops.lamb import FusedLamb
...@@ -1359,26 +1243,21 @@ class DeepSpeedEngine(Module): ...@@ -1359,26 +1243,21 @@ class DeepSpeedEngine(Module):
optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters) optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled(): if not self.fp16_enabled():
logger.warning( logger.warning(f"Currently the convergence of 1-bit Adam is only verified under FP16")
f"Currently the convergence of 1-bit Adam is only verified under FP16"
)
elif self.optimizer_name() == ZERO_ONE_ADAM_OPTIMIZER: elif self.optimizer_name() == ZERO_ONE_ADAM_OPTIMIZER:
assert not self.zero_optimization(), "0/1 Adam is not compatible with ZeRO" assert not self.zero_optimization(), "0/1 Adam is not compatible with ZeRO"
from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters) optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled(): if not self.fp16_enabled():
logger.warning( logger.warning(f'Currently the convergence of 0/1 Adam is only verified under FP16')
f'Currently the convergence of 0/1 Adam is only verified under FP16')
elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER: elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER:
assert not self.zero_optimization(), "1bit-Lamb is not compatible with ZeRO" assert not self.zero_optimization(), "1bit-Lamb is not compatible with ZeRO"
from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb
optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters) optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled(): if not self.fp16_enabled():
logger.warning( logger.warning(f"Currently the convergence of 1-bit Lamb is only verified under FP16")
f"Currently the convergence of 1-bit Lamb is only verified under FP16"
)
else: else:
torch_optimizer = getattr(torch.optim, self.optimizer_name()) torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters) optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
...@@ -1403,7 +1282,8 @@ class DeepSpeedEngine(Module): ...@@ -1403,7 +1282,8 @@ class DeepSpeedEngine(Module):
use_quantizer_kernel, use_quantizer_kernel,
) = self.quantize_training() ) = self.quantize_training()
if quantize_enabled and not quantize_weight_in_forward: if quantize_enabled and not quantize_weight_in_forward:
assert self.fp16_enabled(), "MoQ (quantize in optimization step) weight quantization is only supported for FP16" assert self.fp16_enabled(
), "MoQ (quantize in optimization step) weight quantization is only supported for FP16"
quantizer = None quantizer = None
if quantize_enabled and not quantize_weight_in_forward: if quantize_enabled and not quantize_weight_in_forward:
from deepspeed.runtime.quantize import Quantizer from deepspeed.runtime.quantize import Quantizer
...@@ -1447,9 +1327,7 @@ class DeepSpeedEngine(Module): ...@@ -1447,9 +1327,7 @@ class DeepSpeedEngine(Module):
has_moe_layers=self.has_moe_layers, has_moe_layers=self.has_moe_layers,
) )
else: else:
log_dist( log_dist(f'Creating fp16 optimizer with static loss scale: {self.loss_scale()}', ranks=[0])
f'Creating fp16 optimizer with static loss scale: {self.loss_scale()}',
ranks=[0])
optimizer = FP16_Optimizer( optimizer = FP16_Optimizer(
optimizer, optimizer,
deepspeed=self, deepspeed=self,
...@@ -1460,8 +1338,7 @@ class DeepSpeedEngine(Module): ...@@ -1460,8 +1338,7 @@ class DeepSpeedEngine(Module):
has_moe_layers=self.has_moe_layers, has_moe_layers=self.has_moe_layers,
) )
else: else:
log_dist(f'Creating fp16 unfused optimizer with dynamic loss scale', log_dist(f'Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0])
ranks=[0])
optimizer = FP16_UnfusedOptimizer( optimizer = FP16_UnfusedOptimizer(
optimizer, optimizer,
deepspeed=self, deepspeed=self,
...@@ -1484,19 +1361,20 @@ class DeepSpeedEngine(Module): ...@@ -1484,19 +1361,20 @@ class DeepSpeedEngine(Module):
log_dist('Creating BF16 optimizer', ranks=[0]) log_dist('Creating BF16 optimizer', ranks=[0])
timers = self.timers if self.wall_clock_breakdown() else None timers = self.timers if self.wall_clock_breakdown() else None
optimizer = BF16_Optimizer( optimizer = BF16_Optimizer(optimizer,
optimizer, self.param_names,
self.param_names, mpu=self.mpu,
mpu=self.mpu, clip_grad=clip_grad,
clip_grad=clip_grad, allgather_bucket_size=self.zero_allgather_bucket_size(),
allgather_bucket_size=self.zero_allgather_bucket_size(), dp_process_group=self.data_parallel_group,
dp_process_group=self.data_parallel_group, timers=timers)
timers=timers)
return optimizer return optimizer
def _configure_zero_optimizer(self, optimizer): def _configure_zero_optimizer(self, optimizer):
zero_stage = self.zero_optimization_stage() zero_stage = self.zero_optimization_stage()
mics_shard_size = self.mics_shard_size()
model_dtype, grad_accum_dtype = self.get_data_types() model_dtype, grad_accum_dtype = self.get_data_types()
timers = self.timers if self.wall_clock_breakdown() else None timers = self.timers if self.wall_clock_breakdown() else None
...@@ -1514,8 +1392,7 @@ class DeepSpeedEngine(Module): ...@@ -1514,8 +1392,7 @@ class DeepSpeedEngine(Module):
round_robin_gradients = self.zero_round_robin_gradients() round_robin_gradients = self.zero_round_robin_gradients()
assert not isinstance(optimizer, DummyOptim), "zero stage {} requires an optimizer".format(zero_stage) assert not isinstance(optimizer, DummyOptim), "zero stage {} requires an optimizer".format(zero_stage)
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
ranks=[0])
# Overlap and contiguous grads are meaningless in stage 1 and are ignored # Overlap and contiguous grads are meaningless in stage 1 and are ignored
if zero_stage == ZeroStageEnum.optimizer_states: if zero_stage == ZeroStageEnum.optimizer_states:
overlap_comm = False overlap_comm = False
...@@ -1526,9 +1403,7 @@ class DeepSpeedEngine(Module): ...@@ -1526,9 +1403,7 @@ class DeepSpeedEngine(Module):
if isinstance(self.module, PipelineModule): if isinstance(self.module, PipelineModule):
if overlap_comm: if overlap_comm:
logger.warning( logger.warning("Pipeline parallelism does not support overlapped communication, will be disabled.")
"Pipeline parallelism does not support overlapped communication, will be disabled."
)
overlap_comm = False overlap_comm = False
optimizer = DeepSpeedZeroOptimizer( optimizer = DeepSpeedZeroOptimizer(
optimizer, optimizer,
...@@ -1542,10 +1417,8 @@ class DeepSpeedEngine(Module): ...@@ -1542,10 +1417,8 @@ class DeepSpeedEngine(Module):
reduce_bucket_size=self.zero_reduce_bucket_size(), reduce_bucket_size=self.zero_reduce_bucket_size(),
allgather_bucket_size=self.zero_allgather_bucket_size(), allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.data_parallel_group, dp_process_group=self.data_parallel_group,
expert_parallel_group=self.expert_parallel_group expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None,
if self.has_moe_layers else None, expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None,
expert_data_parallel_group=self.expert_data_parallel_group
if self.has_moe_layers else None,
reduce_scatter=self.zero_reduce_scatter(), reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=overlap_comm, overlap_comm=overlap_comm,
cpu_offload=self.zero_cpu_offload(), cpu_offload=self.zero_cpu_offload(),
...@@ -1557,8 +1430,7 @@ class DeepSpeedEngine(Module): ...@@ -1557,8 +1430,7 @@ class DeepSpeedEngine(Module):
partition_grads=zero_stage == ZeroStageEnum.gradients, partition_grads=zero_stage == ZeroStageEnum.gradients,
round_robin_gradients=round_robin_gradients, round_robin_gradients=round_robin_gradients,
has_moe_layers=self.has_moe_layers, has_moe_layers=self.has_moe_layers,
fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients( fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(),
),
communication_data_type=self.communication_data_type, communication_data_type=self.communication_data_type,
elastic_checkpoint=self.zero_elastic_checkpoint()) elastic_checkpoint=self.zero_elastic_checkpoint())
...@@ -1566,21 +1438,27 @@ class DeepSpeedEngine(Module): ...@@ -1566,21 +1438,27 @@ class DeepSpeedEngine(Module):
assert not self.has_moe_layers, "MoE not supported with Stage 3" assert not self.has_moe_layers, "MoE not supported with Stage 3"
if isinstance(optimizer, DummyOptim): if isinstance(optimizer, DummyOptim):
log_dist("Creating ZeRO Offload", ranks=[0]) log_dist("Creating ZeRO Offload", ranks=[0])
optimizer = DeepSpeedZeRoOffload( optimizer = DeepSpeedZeRoOffload(self.module,
self.module, timers=timers,
timers=timers, ds_config=self.config,
ds_config=self.config, overlap_comm=self.zero_overlap_comm(),
overlap_comm=self.zero_overlap_comm(), prefetch_bucket_size=self.zero_prefetch_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(), max_reuse_distance=self.zero_max_reuse_distance(),
max_reuse_distance=self.zero_max_reuse_distance(), max_live_parameters=self.zero_max_live_parameters(),
max_live_parameters=self.zero_max_live_parameters(), param_persistence_threshold=self.zero_param_persistence_threshold(),
param_persistence_threshold=self.zero_param_persistence_threshold(), model_persistence_threshold=self.zero_model_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(), offload_param_config=self.zero_offload_param(),
offload_param_config=self.zero_offload_param(), mpu=self.mpu)
mpu=self.mpu)
else: else:
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', log_dist(
ranks=[0]) f'Creating fp16 ZeRO stage {zero_stage} optimizer,'
f' MiCS is enabled {mics_shard_size>0},'
f' Hierarchical params gather {self._config.mics_hierarchial_params_gather}',
ranks=[0])
if mics_shard_size > 0:
return self._return_mics_optimizer(optimizer, timers)
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
optimizer = DeepSpeedZeroOptimizer_Stage3( optimizer = DeepSpeedZeroOptimizer_Stage3(
self.module, self.module,
...@@ -1616,6 +1494,37 @@ class DeepSpeedEngine(Module): ...@@ -1616,6 +1494,37 @@ class DeepSpeedEngine(Module):
return optimizer return optimizer
def _return_mics_optimizer(self, basic_optimizer, timers):
from deepspeed.runtime.zero.mics import MiCS_Optimizer
optimizer = MiCS_Optimizer(self.module,
basic_optimizer,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
sub_group_size=self.zero_sub_group_size(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
aio_config=self.aio_config(),
communication_data_type=self.communication_data_type)
return optimizer
def _configure_eigenvalue(self): def _configure_eigenvalue(self):
eigenvalue = Eigenvalue( eigenvalue = Eigenvalue(
verbose=self.eigenvalue_verbose(), verbose=self.eigenvalue_verbose(),
...@@ -1644,9 +1553,7 @@ class DeepSpeedEngine(Module): ...@@ -1644,9 +1553,7 @@ class DeepSpeedEngine(Module):
@staticmethod @staticmethod
def is_iterable_style_dataset(obj): def is_iterable_style_dataset(obj):
return isinstance(obj, return isinstance(obj, torch.utils.data.IterableDataset) # hasattr(obj, "__iter__") should work as well
torch.utils.data.IterableDataset
) # hasattr(obj, "__iter__") should work as well
def dataloader_drop_last(self): def dataloader_drop_last(self):
return self._config.dataloader_drop_last return self._config.dataloader_drop_last
...@@ -1669,8 +1576,7 @@ class DeepSpeedEngine(Module): ...@@ -1669,8 +1576,7 @@ class DeepSpeedEngine(Module):
data_sampler=None, data_sampler=None,
collate_fn=None, collate_fn=None,
num_local_io_workers=None): num_local_io_workers=None):
if not (self.is_map_style_dataset(dataset) if not (self.is_map_style_dataset(dataset) or self.is_iterable_style_dataset(dataset)):
or self.is_iterable_style_dataset(dataset)):
raise ValueError("Training data must be a torch Dataset") raise ValueError("Training data must be a torch Dataset")
if batch_size is None: if batch_size is None:
...@@ -1702,33 +1608,26 @@ class DeepSpeedEngine(Module): ...@@ -1702,33 +1608,26 @@ class DeepSpeedEngine(Module):
deepspeed_dataloader_config = {} deepspeed_dataloader_config = {}
if self.curriculum_learning_enabled(): if self.curriculum_learning_enabled():
deepspeed_dataloader_config = { deepspeed_dataloader_config = {
CURRICULUM_LEARNING: CURRICULUM_LEARNING: self.curriculum_learning_enabled(),
self.curriculum_learning_enabled(), DATA_EFFICIENCY: self.data_efficiency_config(),
DATA_EFFICIENCY: DATA_PARALLEL_GROUP: self.data_parallel_group,
self.data_efficiency_config(), GRADIENT_ACCUMULATION_STEPS: self.gradient_accumulation_steps(),
DATA_PARALLEL_GROUP: GLOBAL_RANK: self.global_rank,
self.data_parallel_group, DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
GRADIENT_ACCUMULATION_STEPS:
self.gradient_accumulation_steps(),
GLOBAL_RANK:
self.global_rank,
DATA_SAMPLING_NUM_WORKERS:
self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
} }
return DeepSpeedDataLoader( return DeepSpeedDataLoader(dataset=dataset,
dataset=dataset, batch_size=batch_size,
batch_size=batch_size, pin_memory=pin_memory,
pin_memory=pin_memory, collate_fn=collate_fn,
collate_fn=collate_fn, local_rank=self.local_rank,
local_rank=self.local_rank, tput_timer=deepspeed_io_timer,
tput_timer=deepspeed_io_timer, num_local_io_workers=num_local_io_workers,
num_local_io_workers=num_local_io_workers, data_sampler=data_sampler,
data_sampler=data_sampler, data_parallel_world_size=data_parallel_world_size,
data_parallel_world_size=data_parallel_world_size, data_parallel_rank=data_parallel_rank,
data_parallel_rank=data_parallel_rank, dataloader_drop_last=self.dataloader_drop_last(),
dataloader_drop_last=self.dataloader_drop_last(), deepspeed_dataloader_config=deepspeed_dataloader_config)
deepspeed_dataloader_config=deepspeed_dataloader_config)
def train(self, mode=True): def train(self, mode=True):
r"""""" r""""""
...@@ -1755,9 +1654,7 @@ class DeepSpeedEngine(Module): ...@@ -1755,9 +1654,7 @@ class DeepSpeedEngine(Module):
else: else:
scaled_loss = prescaled_loss scaled_loss = prescaled_loss
if self.warn_unscaled_loss: if self.warn_unscaled_loss:
logger.warning( logger.warning(f"DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}")
f"DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}"
)
self.warn_unscaled_loss = False self.warn_unscaled_loss = False
return scaled_loss return scaled_loss
...@@ -1775,9 +1672,8 @@ class DeepSpeedEngine(Module): ...@@ -1775,9 +1672,8 @@ class DeepSpeedEngine(Module):
else: else:
see_memory_usage("Engine before forward", force=self.memory_breakdown()) see_memory_usage("Engine before forward", force=self.memory_breakdown())
flops_profiler_active = (self.flops_profiler_enabled() and self.global_steps flops_profiler_active = (self.flops_profiler_enabled()
== self.flops_profiler_profile_step() and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)
and self.global_rank == 0)
# used to check quantization happens at step 0! # used to check quantization happens at step 0!
if self.global_steps == 0 and hasattr(self, "compression_scheduler"): if self.global_steps == 0 and hasattr(self, "compression_scheduler"):
...@@ -1806,10 +1702,7 @@ class DeepSpeedEngine(Module): ...@@ -1806,10 +1702,7 @@ class DeepSpeedEngine(Module):
if self.module.training and self.curriculum_enabled_legacy(): if self.module.training and self.curriculum_enabled_legacy():
self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1) self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
if self.curriculum_params_legacy()["curriculum_type"] == "seqlen": if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
kwargs.update({ kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
"curriculum_seqlen":
self.curriculum_scheduler_legacy.get_current_difficulty()
})
if self.module.training and self.random_ltd_enabled(): if self.module.training and self.random_ltd_enabled():
self.random_ltd_scheduler.update_seq(self.global_steps) self.random_ltd_scheduler.update_seq(self.global_steps)
...@@ -1819,7 +1712,6 @@ class DeepSpeedEngine(Module): ...@@ -1819,7 +1712,6 @@ class DeepSpeedEngine(Module):
# we are in a forward pass. # we are in a forward pass.
for module in self.module.modules(): for module in self.module.modules():
module._parameters._in_forward = True module._parameters._in_forward = True
pass
self._start_timers(self.engine_timers.forward_timers) self._start_timers(self.engine_timers.forward_timers)
...@@ -1844,9 +1736,7 @@ class DeepSpeedEngine(Module): ...@@ -1844,9 +1736,7 @@ class DeepSpeedEngine(Module):
if self.autotuning_profile_model_info(): if self.autotuning_profile_model_info():
activation_mem = get_ma_status() - ma activation_mem = get_ma_status() - ma
self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
print_json_dist(self.autotuning_model_info, print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())
[0],
path=self.autotuning_model_info_path())
exit() exit()
else: else:
see_memory_usage("Engine after forward", force=self.memory_breakdown()) see_memory_usage("Engine after forward", force=self.memory_breakdown())
...@@ -1897,27 +1787,21 @@ class DeepSpeedEngine(Module): ...@@ -1897,27 +1787,21 @@ class DeepSpeedEngine(Module):
f'allreduce_gradients() is not valid when bfloat+pipeline_parallelism is enabled' f'allreduce_gradients() is not valid when bfloat+pipeline_parallelism is enabled'
# Pass (PP) gas boundary flag to optimizer (required for zero) # Pass (PP) gas boundary flag to optimizer (required for zero)
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary( self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
)
# ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients(): if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue() self.optimizer.overlapping_partition_gradients_reduce_epilogue()
# Communicate only at gradient accumulation boundaries # Communicate only at gradient accumulation boundaries
elif self.is_gradient_accumulation_boundary(): elif self.is_gradient_accumulation_boundary():
if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states and hasattr(
self.optimizer.reduce_gradients( self.optimizer, 'reduce_gradients'):
pipeline_parallel=self.pipeline_parallelism) self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
else: else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
@instrument_w_nvtx @instrument_w_nvtx
def backward(self, def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True):
loss,
allreduce_gradients=True,
release_loss=False,
retain_graph=False,
scale_wrt_gas=True):
r"""Execute backward pass on the loss r"""Execute backward pass on the loss
Arguments: Arguments:
loss: Torch tensor on which to execute backward propagation loss: Torch tensor on which to execute backward propagation
...@@ -1932,9 +1816,7 @@ class DeepSpeedEngine(Module): ...@@ -1932,9 +1816,7 @@ class DeepSpeedEngine(Module):
scale_wrt_gas = self.scale_wrt_gas scale_wrt_gas = self.scale_wrt_gas
if not allreduce_gradients: if not allreduce_gradients:
logger.warning( logger.warning(f"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed")
f"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed"
)
# scale loss w.r.t. gradient accumulation if needed # scale loss w.r.t. gradient accumulation if needed
if self.gradient_accumulation_steps() > 1 and scale_wrt_gas: if self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
...@@ -1959,16 +1841,13 @@ class DeepSpeedEngine(Module): ...@@ -1959,16 +1841,13 @@ class DeepSpeedEngine(Module):
self._start_timers(self.engine_timers.backward_inner_timers) self._start_timers(self.engine_timers.backward_inner_timers)
if self.zero_optimization(): if self.zero_optimization():
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary( self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
)
self.optimizer.backward(loss, retain_graph=retain_graph) self.optimizer.backward(loss, retain_graph=retain_graph)
elif self.amp_enabled(): elif self.amp_enabled():
# AMP requires delaying unscale when inside gradient accumulation boundaries # AMP requires delaying unscale when inside gradient accumulation boundaries
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = not self.is_gradient_accumulation_boundary() delay_unscale = not self.is_gradient_accumulation_boundary()
with amp.scale_loss(loss, with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
self.optimizer,
delay_unscale=delay_unscale) as scaled_loss:
scaled_loss.backward(retain_graph=retain_graph) scaled_loss.backward(retain_graph=retain_graph)
elif self.fp16_enabled(): elif self.fp16_enabled():
if self.eigenvalue_enabled(): if self.eigenvalue_enabled():
...@@ -2051,22 +1930,17 @@ class DeepSpeedEngine(Module): ...@@ -2051,22 +1930,17 @@ class DeepSpeedEngine(Module):
param.grad = None param.grad = None
def clip_fp32_gradients(self): def clip_fp32_gradients(self):
clip_grad_norm_(parameters=self.module.parameters(), clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu)
max_norm=self.gradient_clipping(),
mpu=self.mpu)
def _take_model_step(self, lr_kwargs, block_eigenvalue={}): def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
if self.gradient_clipping() > 0.0: if self.gradient_clipping() > 0.0:
if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()):
or self.zero_optimization()):
self.clip_fp32_gradients() self.clip_fp32_gradients()
elif self.amp_enabled(): elif self.amp_enabled():
# AMP's recommended way of doing clipping # AMP's recommended way of doing clipping
# https://nvidia.github.io/apex/advanced.html#gradient-clipping # https://nvidia.github.io/apex/advanced.html#gradient-clipping
master_params = amp.master_params(self.optimizer) master_params = amp.master_params(self.optimizer)
clip_grad_norm_(parameters=master_params, clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)
max_norm=self.gradient_clipping(),
mpu=self.mpu)
self.optimizer.step() self.optimizer.step()
if hasattr(self.optimizer, '_global_grad_norm'): if hasattr(self.optimizer, '_global_grad_norm'):
...@@ -2087,7 +1961,7 @@ class DeepSpeedEngine(Module): ...@@ -2087,7 +1961,7 @@ class DeepSpeedEngine(Module):
# the behaviour that we want # the behaviour that we want
if self.bfloat16_enabled(): if self.bfloat16_enabled():
# TODO: Temporary until bf16_optimizer and zero_optimizer are integrated # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated
if self.zero_optimization(): if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"):
self.optimizer.zero_grad() self.optimizer.zero_grad()
else: else:
pass pass
...@@ -2132,8 +2006,7 @@ class DeepSpeedEngine(Module): ...@@ -2132,8 +2006,7 @@ class DeepSpeedEngine(Module):
# Check early because self.global_steps is incremented at some point here. # Check early because self.global_steps is incremented at some point here.
# TODO: Delay self.global_steps increment until very end of this function. # TODO: Delay self.global_steps increment until very end of this function.
flops_profiler_active = self.flops_profiler_enabled( flops_profiler_active = self.flops_profiler_enabled(
) and self.global_steps == self.flops_profiler_profile_step( ) and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0
) and self.global_rank == 0
self._start_timers(self.engine_timers.step_timers) self._start_timers(self.engine_timers.step_timers)
...@@ -2148,20 +2021,16 @@ class DeepSpeedEngine(Module): ...@@ -2148,20 +2021,16 @@ class DeepSpeedEngine(Module):
if self.is_gradient_accumulation_boundary(): if self.is_gradient_accumulation_boundary():
self.gas_boundary_ctr += 1 self.gas_boundary_ctr += 1
if (self.eigenvalue_enabled() and if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
(self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
and self.quantizer.any_precision_switch()): and self.quantizer.any_precision_switch()):
log_dist(f"computing eigenvalue...", ranks=[0]) log_dist(f"computing eigenvalue...", ranks=[0])
self.block_eigenvalue = self.eigenvalue.compute_eigenvalue( self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device,
self.module, self.optimizer.cur_scale)
self.device,
self.optimizer.cur_scale)
if self.progressive_layer_drop: if self.progressive_layer_drop:
self.progressive_layer_drop.update_state(self.global_steps) self.progressive_layer_drop.update_state(self.global_steps)
if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()
self.eigenvalue_gas_boundary_resolution()
and self.quantizer.any_precision_switch()): and self.quantizer.any_precision_switch()):
self._take_model_step(lr_kwargs, self.block_eigenvalue) self._take_model_step(lr_kwargs, self.block_eigenvalue)
else: else:
...@@ -2169,8 +2038,7 @@ class DeepSpeedEngine(Module): ...@@ -2169,8 +2038,7 @@ class DeepSpeedEngine(Module):
report_progress = self.global_rank == 0 if self.global_rank else True report_progress = self.global_rank == 0 if self.global_rank else True
self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), report_speed=report_progress)
report_speed=report_progress)
self._stop_timers(self.engine_timers.step_timers) self._stop_timers(self.engine_timers.step_timers)
...@@ -2178,9 +2046,7 @@ class DeepSpeedEngine(Module): ...@@ -2178,9 +2046,7 @@ class DeepSpeedEngine(Module):
if self.monitor.enabled: if self.monitor.enabled:
if self.is_gradient_accumulation_boundary(): if self.is_gradient_accumulation_boundary():
if self.global_rank == 0: if self.global_rank == 0:
self.summary_events = [(f"Train/Samples/lr", self.summary_events = [(f"Train/Samples/lr", self.get_lr()[0], self.global_samples)]
self.get_lr()[0],
self.global_samples)]
if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"): if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"):
self.summary_events.append(( self.summary_events.append((
...@@ -2189,8 +2055,8 @@ class DeepSpeedEngine(Module): ...@@ -2189,8 +2055,8 @@ class DeepSpeedEngine(Module):
self.global_samples, self.global_samples,
)) ))
if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % if (self.eigenvalue_enabled()
self.eigenvalue_gas_boundary_resolution()): and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()):
ev_values = self.block_eigenvalue.values() ev_values = self.block_eigenvalue.values()
for i in range(len(ev_values)): for i in range(len(ev_values)):
self.summary_events.append(( self.summary_events.append((
...@@ -2214,14 +2080,12 @@ class DeepSpeedEngine(Module): ...@@ -2214,14 +2080,12 @@ class DeepSpeedEngine(Module):
) )
self.flops_profiler.end_profile() self.flops_profiler.end_profile()
if self.autotuning_enabled() and self.global_steps == ( if self.autotuning_enabled() and self.global_steps == (self.autotuning_end_profile_step() + 1):
self.autotuning_end_profile_step() + 1):
self._autotuning_exit() self._autotuning_exit()
if self.wall_clock_breakdown(): if self.wall_clock_breakdown():
# Log micro timing and reset # Log micro timing and reset
self.timers.log(names=self.engine_timers.micro_timers, self.timers.log(names=self.engine_timers.micro_timers, memory_breakdown=self.memory_breakdown())
memory_breakdown=self.memory_breakdown())
if self.wall_clock_breakdown() or self.flops_profiler_enabled(): if self.wall_clock_breakdown() or self.flops_profiler_enabled():
# Log global timing and reset # Log global timing and reset
...@@ -2255,13 +2119,10 @@ class DeepSpeedEngine(Module): ...@@ -2255,13 +2119,10 @@ class DeepSpeedEngine(Module):
FORWARD_GLOBAL_TIMER, FORWARD_GLOBAL_TIMER,
BACKWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER,
STEP_GLOBAL_TIMER, STEP_GLOBAL_TIMER,
], ], reset=False)
reset=False) titer = msg[FORWARD_GLOBAL_TIMER] + msg[BACKWARD_GLOBAL_TIMER] + msg[STEP_GLOBAL_TIMER]
titer = msg[FORWARD_GLOBAL_TIMER] + msg[BACKWARD_GLOBAL_TIMER] + msg[
STEP_GLOBAL_TIMER]
msg["latency"] = titer msg["latency"] = titer
msg["FLOPS_per_gpu"] = self.flops * 1_000_000 * self.gradient_accumulation_steps( msg["FLOPS_per_gpu"] = self.flops * 1_000_000 * self.gradient_accumulation_steps() / titer
) / titer
msg["throughput"] = self.train_batch_size() * 1_000_000 / \ msg["throughput"] = self.train_batch_size() * 1_000_000 / \
msg["latency"] msg["latency"]
print_json_dist(msg, [0], path=self.autotuning_metric_path()) print_json_dist(msg, [0], path=self.autotuning_metric_path())
...@@ -2335,8 +2196,7 @@ class DeepSpeedEngine(Module): ...@@ -2335,8 +2196,7 @@ class DeepSpeedEngine(Module):
def _report_progress(self, step): def _report_progress(self, step):
lr = self.get_lr() lr = self.get_lr()
mom = self.get_mom() mom = self.get_mom()
log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0])
ranks=[0])
def allreduce_bucket(self, bucket, dp_group): def allreduce_bucket(self, bucket, dp_group):
tensor = self.flatten(bucket) tensor = self.flatten(bucket)
...@@ -2352,10 +2212,8 @@ class DeepSpeedEngine(Module): ...@@ -2352,10 +2212,8 @@ class DeepSpeedEngine(Module):
dist.all_reduce(tensor_to_allreduce, group=dp_group) dist.all_reduce(tensor_to_allreduce, group=dp_group)
if self.gradient_average: if self.gradient_average:
if self.gradient_predivide_factor() != dist.get_world_size( if self.gradient_predivide_factor() != dist.get_world_size(group=dp_group):
group=dp_group): tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
tensor_to_allreduce.mul_(self.gradient_predivide_factor() /
dist.get_world_size(group=dp_group))
else: else:
tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group)) tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group))
dist.all_reduce(tensor_to_allreduce, group=dp_group) dist.all_reduce(tensor_to_allreduce, group=dp_group)
...@@ -2397,9 +2255,7 @@ class DeepSpeedEngine(Module): ...@@ -2397,9 +2255,7 @@ class DeepSpeedEngine(Module):
# rank is reducing the same size. In some cases it may make # rank is reducing the same size. In some cases it may make
# sense in the future to support the ability to average not # sense in the future to support the ability to average not
# w.r.t. world size but with a different value. # w.r.t. world size but with a different value.
param.grad = torch.zeros(param.size(), param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
dtype=param.dtype,
device=param.device)
grad_data = param.grad.data grad_data = param.grad.data
if param_name in self.sparse_tensor_module_names or grad_data.is_sparse: if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
...@@ -2426,9 +2282,7 @@ class DeepSpeedEngine(Module): ...@@ -2426,9 +2282,7 @@ class DeepSpeedEngine(Module):
if bucket_type == SparseTensor.type(): if bucket_type == SparseTensor.type():
self.sparse_allreduce_no_retain(bucket, dp_group=dp_group) self.sparse_allreduce_no_retain(bucket, dp_group=dp_group)
else: else:
self.allreduce_no_retain(bucket, self.allreduce_no_retain(bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)
dp_group=dp_group,
numel_per_bucket=elements_per_buffer)
def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
for ep_name, expert_grads_group in expert_grads.items(): for ep_name, expert_grads_group in expert_grads.items():
...@@ -2436,15 +2290,12 @@ class DeepSpeedEngine(Module): ...@@ -2436,15 +2290,12 @@ class DeepSpeedEngine(Module):
for i, bucket_tuple in enumerate(expert_split_buckets): for i, bucket_tuple in enumerate(expert_split_buckets):
bucket_type, bucket = bucket_tuple bucket_type, bucket = bucket_tuple
if bucket_type == SparseTensor.type(): if bucket_type == SparseTensor.type():
self.sparse_allreduce_no_retain( self.sparse_allreduce_no_retain(bucket, groups._get_expert_data_parallel_group(ep_name))
bucket,
groups._get_expert_data_parallel_group(ep_name))
else: else:
# Separate between diff groups # Separate between diff groups
self.allreduce_no_retain( self.allreduce_no_retain(bucket,
bucket, dp_group=groups._get_expert_data_parallel_group(ep_name),
dp_group=groups._get_expert_data_parallel_group(ep_name), numel_per_bucket=elements_per_buffer)
numel_per_bucket=elements_per_buffer)
def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000): def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
if grads is None: if grads is None:
...@@ -2487,8 +2338,7 @@ class DeepSpeedEngine(Module): ...@@ -2487,8 +2338,7 @@ class DeepSpeedEngine(Module):
if self.postscale_gradients(): if self.postscale_gradients():
if self.gradient_average: if self.gradient_average:
values.mul_(self.gradient_predivide_factor() / values.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
dist.get_world_size(group=dp_group))
else: else:
values.mul_(1. / dist.get_world_size(group=dp_group)) values.mul_(1. / dist.get_world_size(group=dp_group))
...@@ -2509,36 +2359,25 @@ class DeepSpeedEngine(Module): ...@@ -2509,36 +2359,25 @@ class DeepSpeedEngine(Module):
if value.dim() == 1: if value.dim() == 1:
if fill_size > 0: if fill_size > 0:
value = torch.cat([value, value.new_empty(fill_size)]) value = torch.cat([value, value.new_empty(fill_size)])
tensor_list = [ tensor_list = [value.new_empty(max_size) for _ in range(dist.get_world_size(group=dp_group))]
value.new_empty(max_size)
for _ in range(dist.get_world_size(group=dp_group))
]
else: else:
if fill_size > 0: if fill_size > 0:
value = torch.cat([value, value.new_empty(fill_size, value.size()[1])]) value = torch.cat([value, value.new_empty(fill_size, value.size()[1])])
tensor_list = [ tensor_list = [
value.new_empty(max_size, value.new_empty(max_size,
value.size()[1]) value.size()[1]) for _ in range(dist.get_world_size(group=dp_group))
for _ in range(dist.get_world_size(group=dp_group))
] ]
dist.all_gather(tensor_list, value, group=dp_group) dist.all_gather(tensor_list, value, group=dp_group)
tensors = [] tensors = []
for dev_idx, t in enumerate(tensor_list): for dev_idx, t in enumerate(tensor_list):
size = all_sizes[dev_idx][0] size = all_sizes[dev_idx][0]
tensors.append( tensors.append(t.index_select(0, torch.arange(size, dtype=torch.long, device=self.device)))
t.index_select(0,
torch.arange(size,
dtype=torch.long,
device=self.device)))
return tensors return tensors
def all_gather_scalar(self, value, dp_group): def all_gather_scalar(self, value, dp_group):
tensor_list = [ tensor_list = [value.new_zeros(value.size()) for _ in range(dist.get_world_size(group=dp_group))]
value.new_zeros(value.size())
for _ in range(dist.get_world_size(group=dp_group))
]
dist.all_gather(tensor_list, value, group=dp_group) dist.all_gather(tensor_list, value, group=dp_group)
return tensor_list return tensor_list
...@@ -2558,20 +2397,19 @@ class DeepSpeedEngine(Module): ...@@ -2558,20 +2397,19 @@ class DeepSpeedEngine(Module):
num_experts=1, num_experts=1,
checkpoint_engine=TorchCheckpointEngine()): checkpoint_engine=TorchCheckpointEngine()):
if old_moe_load: if old_moe_load:
expp_rank = groups._get_expert_data_parallel_rank( expp_rank = groups._get_expert_data_parallel_rank(groups._get_max_expert_size_name())
groups._get_max_expert_size_name())
num_local_experts = max( num_local_experts = max(num_experts) // groups._get_expert_parallel_world_size(
num_experts) // groups._get_expert_parallel_world_size( groups._get_max_expert_size_name())
groups._get_max_expert_size_name())
for local_expert_id in range(num_local_experts): for local_expert_id in range(num_local_experts):
global_expert_id = expp_rank * num_local_experts + local_expert_id global_expert_id = expp_rank * num_local_experts + local_expert_id
expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name( expert_state_dict = checkpoint_engine.load(
checkpoint_path, DeepSpeedEngine._get_expert_ckpt_name(
-1, # -1 means ignore layer_id checkpoint_path,
global_expert_id, -1, # -1 means ignore layer_id
tag, global_expert_id,
mpu), tag,
mpu),
map_location=torch.device('cpu')) map_location=torch.device('cpu'))
# Updating global -> local expert ids # Updating global -> local expert ids
...@@ -2592,41 +2430,45 @@ class DeepSpeedEngine(Module): ...@@ -2592,41 +2430,45 @@ class DeepSpeedEngine(Module):
# loop all local_experts # loop all local_experts
for local_expert_id in range(num_local_experts): for local_expert_id in range(num_local_experts):
global_expert_id = expp_rank * num_local_experts + local_expert_id global_expert_id = expp_rank * num_local_experts + local_expert_id
expert_state_dict = checkpoint_engine.load( expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name(
DeepSpeedEngine._get_expert_ckpt_name( checkpoint_path, moe_layer_id, global_expert_id, tag, mpu),
checkpoint_path, map_location=torch.device('cpu'))
moe_layer_id,
global_expert_id,
tag,
mpu),
map_location=torch.device('cpu'))
# print(expert_state_dict.keys()) # print(expert_state_dict.keys())
# Updating global -> local expert ids # Updating global -> local expert ids
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
for key in list(expert_state_dict.keys()): for key in list(expert_state_dict.keys()):
local_key = key.replace( local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',
f'{moe_str_prefix}{global_expert_id}', f'{moe_str_prefix}{local_expert_id}')
f'{moe_str_prefix}{local_expert_id}')
expert_state_dict[local_key] = expert_state_dict.pop(key) expert_state_dict[local_key] = expert_state_dict.pop(key)
state_dict.update(expert_state_dict) state_dict.update(expert_state_dict)
moe_layer_id += 1 moe_layer_id += 1
def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None):
module_state_dict = checkpoint['module']
if custom_load_fn: if custom_load_fn:
custom_load_fn(src=state_dict, dst=self.module) custom_load_fn(src=module_state_dict, dst=self.module)
else: else:
self.module.load_state_dict(state_dict, # TODO self.module.load_state_dict(
strict=strict) module_state_dict, # TODO
strict=strict)
if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None:
saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS]
for param in self.module.parameters():
if param.requires_grad:
continue
if param not in self.param_names:
raise ValueError(f"failed to find frozen {param} in named params")
name = self.param_names[param]
if hasattr(param, 'ds_id'):
param.ds_tensor.data.copy_(saved_frozen_params[name].data)
else:
param.data.copy_(saved_frozen_params[name].data)
def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode): def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode):
return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}' return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}'
def _get_rank_zero_ckpt_name(self, def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank, bf16_mode):
checkpoints_path,
tag,
mp_rank,
dp_rank,
bf16_mode):
file_prefix = self._get_zero_ckpt_prefix(dp_rank, bf16_mode=bf16_mode) file_prefix = self._get_zero_ckpt_prefix(dp_rank, bf16_mode=bf16_mode)
zero_ckpt_name = os.path.join( zero_ckpt_name = os.path.join(
checkpoints_path, checkpoints_path,
...@@ -2639,11 +2481,7 @@ class DeepSpeedEngine(Module): ...@@ -2639,11 +2481,7 @@ class DeepSpeedEngine(Module):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
pp_rank = dist.get_rank(group=self.optimizer.dp_process_group) pp_rank = dist.get_rank(group=self.optimizer.dp_process_group)
bf16_mode = self.bfloat16_enabled() bf16_mode = self.bfloat16_enabled()
return self._get_rank_zero_ckpt_name(checkpoints_path, return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank, bf16_mode)
tag,
mp_rank,
pp_rank,
bf16_mode)
def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None): def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
if mp_placeholder is not None: if mp_placeholder is not None:
...@@ -2653,8 +2491,7 @@ class DeepSpeedEngine(Module): ...@@ -2653,8 +2491,7 @@ class DeepSpeedEngine(Module):
mp_rank_str = f"{mp_rank:02d}" mp_rank_str = f"{mp_rank:02d}"
if self.zero_optimization_partition_weights(): if self.zero_optimization_partition_weights():
filename = "zero_pp_rank_{}".format( filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
dist.get_rank(group=self.optimizer.dp_process_group))
ckpt_name = os.path.join( ckpt_name = os.path.join(
checkpoints_path, checkpoints_path,
str(tag), str(tag),
...@@ -2670,10 +2507,8 @@ class DeepSpeedEngine(Module): ...@@ -2670,10 +2507,8 @@ class DeepSpeedEngine(Module):
def _get_optimizer_ckpt_name(self, checkpoints_path, tag, expp_rank): def _get_optimizer_ckpt_name(self, checkpoints_path, tag, expp_rank):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
ckpt_name = os.path.join( ckpt_name = os.path.join(checkpoints_path, str(tag),
checkpoints_path, f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt')
str(tag),
f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt')
return ckpt_name return ckpt_name
@staticmethod @staticmethod
...@@ -2681,24 +2516,17 @@ class DeepSpeedEngine(Module): ...@@ -2681,24 +2516,17 @@ class DeepSpeedEngine(Module):
mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank() mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()
if layer_id <= -1: if layer_id <= -1:
# Used to support old checkpoint loading # Used to support old checkpoint loading
ckpt_name = os.path.join( ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag),
checkpoints_path, f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
'' if tag is None else str(tag),
f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
else: else:
# Used to support new checkpoint loading # Used to support new checkpoint loading
ckpt_name = os.path.join( ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag),
checkpoints_path, f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt')
'' if tag is None else str(tag),
f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt'
)
return ckpt_name return ckpt_name
def _get_all_ckpt_names(self, checkpoints_path, tag): def _get_all_ckpt_names(self, checkpoints_path, tag):
# It is required that (checkpoints_path, tag) are consistent among all ranks. # It is required that (checkpoints_path, tag) are consistent among all ranks.
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
tag,
mp_placeholder="*")
import glob import glob
ckpt_files = glob.glob(ckpt_file_pattern) ckpt_files = glob.glob(ckpt_file_pattern)
...@@ -2738,17 +2566,14 @@ class DeepSpeedEngine(Module): ...@@ -2738,17 +2566,14 @@ class DeepSpeedEngine(Module):
""" """
if tag is None: if tag is None:
latest_tag = "latest_universal" if self.load_universal_checkpoint( latest_tag = "latest_universal" if self.load_universal_checkpoint() else "latest"
) else "latest"
latest_path = os.path.join(load_dir, latest_tag) latest_path = os.path.join(load_dir, latest_tag)
if os.path.isfile(latest_path): if os.path.isfile(latest_path):
with open(latest_path, "r") as fd: with open(latest_path, "r") as fd:
tag = fd.read().strip() tag = fd.read().strip()
else: else:
if self.load_universal_checkpoint(): if self.load_universal_checkpoint():
raise ValueError( raise ValueError(f'Invalid for universal checkpoint: {latest_path} does not exist')
f'Invalid for universal checkpoint: {latest_path} does not exist'
)
else: else:
logger.warning( logger.warning(
f"Unable to find latest file at {latest_path}, if trying to load latest " f"Unable to find latest file at {latest_path}, if trying to load latest "
...@@ -2770,10 +2595,7 @@ class DeepSpeedEngine(Module): ...@@ -2770,10 +2595,7 @@ class DeepSpeedEngine(Module):
load_zero_checkpoint = self.zero_optimization() or self.bfloat16_enabled() load_zero_checkpoint = self.zero_optimization() or self.bfloat16_enabled()
if load_zero_checkpoint and load_path is not None: if load_zero_checkpoint and load_path is not None:
success = self._load_zero_checkpoint( success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
load_dir,
tag,
load_optimizer_states=load_optimizer_states)
if not success: if not success:
self.optimizer._restore_from_bit16_weights() self.optimizer._restore_from_bit16_weights()
...@@ -2794,16 +2616,12 @@ class DeepSpeedEngine(Module): ...@@ -2794,16 +2616,12 @@ class DeepSpeedEngine(Module):
from deepspeed.runtime.state_dict_factory import SDLoaderFactory from deepspeed.runtime.state_dict_factory import SDLoaderFactory
ckpt_list = self._get_all_ckpt_names(load_dir, tag) ckpt_list = self._get_all_ckpt_names(load_dir, tag)
sd_loader = SDLoaderFactory.get_sd_loader( sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine)
ckpt_list,
checkpoint_engine=self.checkpoint_engine)
is_pipe_parallel = isinstance(self.module, PipelineModule) is_pipe_parallel = isinstance(self.module, PipelineModule)
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
load_path, checkpoint, _ = sd_loader.load( load_path, checkpoint, _ = sd_loader.load(self.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel)
self.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel
)
if checkpoint is None: if checkpoint is None:
return None, None return None, None
...@@ -2826,7 +2644,7 @@ class DeepSpeedEngine(Module): ...@@ -2826,7 +2644,7 @@ class DeepSpeedEngine(Module):
num_experts=self.num_experts, num_experts=self.num_experts,
checkpoint_engine=self.checkpoint_engine) checkpoint_engine=self.checkpoint_engine)
if not self.load_universal_checkpoint(): if not self.load_universal_checkpoint():
self.load_module_state_dict(state_dict=checkpoint['module'], self.load_module_state_dict(checkpoint=checkpoint,
strict=load_module_strict, strict=load_module_strict,
custom_load_fn=custom_load_fn) custom_load_fn=custom_load_fn)
...@@ -2841,38 +2659,29 @@ class DeepSpeedEngine(Module): ...@@ -2841,38 +2659,29 @@ class DeepSpeedEngine(Module):
largest_group_name = groups._get_max_expert_size_name() largest_group_name = groups._get_max_expert_size_name()
expp_rank = groups._get_expert_parallel_rank(largest_group_name) expp_rank = groups._get_expert_parallel_rank(largest_group_name)
optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank)
optim_checkpoint = self.checkpoint_engine.load( optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu'))
optim_load_path,
map_location=torch.device('cpu'))
else: else:
optim_checkpoint = checkpoint optim_checkpoint = checkpoint
has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled( has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
)
if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state: if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state:
if self.fp16_enabled(): if self.fp16_enabled():
self.optimizer.load_state_dict( self.optimizer.load_state_dict(optim_checkpoint['optimizer'],
optim_checkpoint['optimizer'], load_optimizer_states=load_optimizer_states)
load_optimizer_states=load_optimizer_states)
else: else:
self.optimizer.load_state_dict(optim_checkpoint['optimizer']) self.optimizer.load_state_dict(optim_checkpoint['optimizer'])
if load_lr_scheduler_states and self.lr_scheduler is not None: if load_lr_scheduler_states and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
if self.random_ltd_enabled( if self.random_ltd_enabled() and self.random_ltd_scheduler is not None and 'random_ltd' in checkpoint:
) and self.random_ltd_scheduler is not None and 'random_ltd' in checkpoint:
self.random_ltd_scheduler.load_state_dict(checkpoint['random_ltd']) self.random_ltd_scheduler.load_state_dict(checkpoint['random_ltd'])
if self.training_dataloader is not None and self.curriculum_learning_enabled( if self.training_dataloader is not None and self.curriculum_learning_enabled(
) and 'data_sampler' in checkpoint: ) and 'data_sampler' in checkpoint:
self.training_dataloader.data_sampler.load_state_dict( self.training_dataloader.data_sampler.load_state_dict(checkpoint['data_sampler'])
checkpoint['data_sampler'])
def get_sparse_tensor_module_names(original_set, def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters, loaded_parameters):
loaded_set,
original_parameters,
loaded_parameters):
result = set() result = set()
for name in original_set: for name in original_set:
...@@ -2882,8 +2691,7 @@ class DeepSpeedEngine(Module): ...@@ -2882,8 +2691,7 @@ class DeepSpeedEngine(Module):
for name in loaded_set: for name in loaded_set:
if name in original_parameters: if name in original_parameters:
result.add( result.add(name) # parameter exists in both configs and it was sparse
name) # parameter exists in both configs and it was sparse
return result return result
...@@ -2898,26 +2706,16 @@ class DeepSpeedEngine(Module): ...@@ -2898,26 +2706,16 @@ class DeepSpeedEngine(Module):
self.sparse_tensor_module_names = sparse_tensor_module_names self.sparse_tensor_module_names = sparse_tensor_module_names
else: else:
self.sparse_tensor_module_names = get_sparse_tensor_module_names( self.sparse_tensor_module_names = get_sparse_tensor_module_names(
self.sparse_tensor_module_names, self.sparse_tensor_module_names, sparse_tensor_module_names,
sparse_tensor_module_names, dict(self.module.named_parameters()), checkpoint["module"])
dict(self.module.named_parameters()),
checkpoint["module"])
self.global_steps = checkpoint['global_steps'] self.global_steps = checkpoint['global_steps']
self.global_samples = checkpoint.get( self.global_samples = checkpoint.get('global_samples', self.global_steps * self.train_batch_size())
'global_samples',
self.global_steps * self.train_batch_size())
self.skipped_steps = checkpoint['skipped_steps'] self.skipped_steps = checkpoint['skipped_steps']
self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size'] self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
deepspeed_states = [ deepspeed_states = [
'module', 'module', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'dp_world_size',
'sparse_tensor_module_names', 'mp_world_size', 'data_sampler', 'random_ltd'
'skipped_steps',
'global_steps',
'dp_world_size',
'mp_world_size',
'data_sampler',
'random_ltd'
] ]
client_state = {} client_state = {}
...@@ -2926,11 +2724,7 @@ class DeepSpeedEngine(Module): ...@@ -2926,11 +2724,7 @@ class DeepSpeedEngine(Module):
if load_optimizer_states: if load_optimizer_states:
deepspeed_states.append('optimizer') deepspeed_states.append('optimizer')
client_state = { client_state = {key: value for key, value in checkpoint.items() if not key in deepspeed_states}
key: value
for key,
value in checkpoint.items() if not key in deepspeed_states
}
if not load_optimizer_states and not load_module_only: if not load_optimizer_states and not load_module_only:
client_state['optimizer'] = optim_checkpoint['optimizer'] client_state['optimizer'] = optim_checkpoint['optimizer']
...@@ -2953,28 +2747,18 @@ class DeepSpeedEngine(Module): ...@@ -2953,28 +2747,18 @@ class DeepSpeedEngine(Module):
if zero_sd_list is None: if zero_sd_list is None:
return False return False
self.optimizer.load_state_dict( self.optimizer.load_state_dict(state_dict_list=zero_sd_list,
state_dict_list=zero_sd_list, load_optimizer_states=load_optimizer_states,
load_optimizer_states=load_optimizer_states, load_from_fp32_weights=self.zero_load_from_fp32_weights(),
load_from_fp32_weights=self.zero_load_from_fp32_weights(), checkpoint_folder=checkpoint_folder)
checkpoint_folder=checkpoint_folder)
if self.load_universal_checkpoint(): if self.load_universal_checkpoint():
logger.info( logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}')
f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}'
)
else: else:
logger.info( logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}"
)
return True return True
def _get_mp_rank_zero_checkpoint_names(self, def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
load_dir,
tag,
mp_rank,
dp_world_size,
bf16_mode):
zero_ckpt_names = [] zero_ckpt_names = []
for dp_rank in range(dp_world_size): for dp_rank in range(dp_world_size):
ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir, ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir,
...@@ -2988,18 +2772,16 @@ class DeepSpeedEngine(Module): ...@@ -2988,18 +2772,16 @@ class DeepSpeedEngine(Module):
def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode): def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names( zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(load_dir=load_dir,
load_dir=load_dir, tag=tag,
tag=tag, mp_rank=mp_rank,
mp_rank=mp_rank, dp_world_size=self.loaded_checkpoint_dp_world_size,
dp_world_size=self.loaded_checkpoint_dp_world_size, bf16_mode=bf16_mode)
bf16_mode=bf16_mode)
for i, ckpt_name in enumerate(zero_ckpt_names): for i, ckpt_name in enumerate(zero_ckpt_names):
if not os.path.exists(ckpt_name): if not os.path.exists(ckpt_name):
# transparently handle the old file pattern for optim_states # transparently handle the old file pattern for optim_states
if "optim_states.pt" in ckpt_name: if "optim_states.pt" in ckpt_name:
ckpt_name_try = ckpt_name.replace("_optim_states.pt", ckpt_name_try = ckpt_name.replace("_optim_states.pt", "optim_states.pt")
"optim_states.pt")
if os.path.exists(ckpt_name_try): if os.path.exists(ckpt_name_try):
zero_ckpt_names[i] = ckpt_name_try zero_ckpt_names[i] = ckpt_name_try
continue continue
...@@ -3013,8 +2795,7 @@ class DeepSpeedEngine(Module): ...@@ -3013,8 +2795,7 @@ class DeepSpeedEngine(Module):
if ckpt_name is None: if ckpt_name is None:
_state = {OPTIMIZER_STATE_DICT: None} _state = {OPTIMIZER_STATE_DICT: None}
# Fully load state for current rank # Fully load state for current rank
elif self.zero_elastic_checkpoint() or dist.get_rank( elif self.zero_elastic_checkpoint() or dist.get_rank(group=self.optimizer.dp_process_group) == i:
group=self.optimizer.dp_process_group) == i:
_state = self.checkpoint_engine.load( _state = self.checkpoint_engine.load(
ckpt_name, ckpt_name,
map_location='cpu', map_location='cpu',
...@@ -3024,25 +2805,18 @@ class DeepSpeedEngine(Module): ...@@ -3024,25 +2805,18 @@ class DeepSpeedEngine(Module):
zero_sd_list.append(_state) zero_sd_list.append(_state)
zero_optimizer_sd = [sd[OPTIMIZER_STATE_DICT] for sd in zero_sd_list] zero_optimizer_sd = [sd[OPTIMIZER_STATE_DICT] for sd in zero_sd_list]
logger.info( logger.info(f"successfully read {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}")
f"successfully read {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}"
)
return zero_optimizer_sd return zero_optimizer_sd
def _get_all_zero_checkpoints(self, load_dir, tag): def _get_all_zero_checkpoints(self, load_dir, tag):
for bf16_mode in [self.bfloat16_enabled(), not self.bfloat16_enabled()]: for bf16_mode in [self.bfloat16_enabled(), not self.bfloat16_enabled()]:
zero_ckpt_names = self._get_all_zero_checkpoint_names( zero_ckpt_names = self._get_all_zero_checkpoint_names(load_dir, tag, bf16_mode)
load_dir,
tag,
bf16_mode)
if zero_ckpt_names is not None: if zero_ckpt_names is not None:
# Warn if loading checkpoint of different bit16 type # Warn if loading checkpoint of different bit16 type
if bf16_mode is not self.bfloat16_enabled(): if bf16_mode is not self.bfloat16_enabled():
checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16 checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16
engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16 engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16
logger.warn( logger.warn(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine'
)
return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names) return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names)
return None return None
...@@ -3056,10 +2830,9 @@ class DeepSpeedEngine(Module): ...@@ -3056,10 +2830,9 @@ class DeepSpeedEngine(Module):
dist.all_reduce(max_bhash, op=dist.ReduceOp.MAX) dist.all_reduce(max_bhash, op=dist.ReduceOp.MAX)
dist.all_reduce(min_bhash, op=dist.ReduceOp.MIN) dist.all_reduce(min_bhash, op=dist.ReduceOp.MIN)
valid = all(min_bhash == bhash) and all(max_bhash == bhash) valid = all(min_bhash == bhash) and all(max_bhash == bhash)
msg = ( msg = (f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across "
f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across " "all ranks. Including rank unique information in checkpoint tag could cause issues when "
"all ranks. Including rank unique information in checkpoint tag could cause issues when " "restoring with different world sizes.")
"restoring with different world sizes.")
if self.checkpoint_tag_validation_fail(): if self.checkpoint_tag_validation_fail():
assert valid, msg assert valid, msg
elif not valid: elif not valid:
...@@ -3090,7 +2863,7 @@ class DeepSpeedEngine(Module): ...@@ -3090,7 +2863,7 @@ class DeepSpeedEngine(Module):
# There seems to be issue creating them in parallel # There seems to be issue creating them in parallel
# Ensure save_dir directory exists # Ensure save_dir directory exists
os.makedirs(save_dir, exist_ok=True) self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
dist.barrier() dist.barrier()
if tag is None: if tag is None:
...@@ -3191,15 +2964,9 @@ class DeepSpeedEngine(Module): ...@@ -3191,15 +2964,9 @@ class DeepSpeedEngine(Module):
# let save the moe parameters # let save the moe parameters
for global_expert_id, expert_state_dict in experts_state_dict.items(): for global_expert_id, expert_state_dict in experts_state_dict.items():
# save the moe parameters # save the moe parameters
moe_save_path = self._get_expert_ckpt_name( moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
save_dir,
moe_layer_id,
global_expert_id,
tag,
self.mpu)
if self.random_ltd_enabled(): if self.random_ltd_enabled():
expert_state_dict = remove_random_ltd_state_dict( expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)
expert_state_dict)
self.checkpoint_engine.save(expert_state_dict, moe_save_path) self.checkpoint_engine.save(expert_state_dict, moe_save_path)
moe_layer_id += 1 moe_layer_id += 1
...@@ -3217,9 +2984,7 @@ class DeepSpeedEngine(Module): ...@@ -3217,9 +2984,7 @@ class DeepSpeedEngine(Module):
# Save optimizer states. They are different across each exp parallel rank. # Save optimizer states. They are different across each exp parallel rank.
optimizer_state = { optimizer_state = {
'optimizer': 'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None
self.optimizer.state_dict()
if self.optimizer and not self.zero_optimization() else None
} }
# TODO: why use BufferedWriter not the path # TODO: why use BufferedWriter not the path
file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
...@@ -3234,15 +2999,12 @@ class DeepSpeedEngine(Module): ...@@ -3234,15 +2999,12 @@ class DeepSpeedEngine(Module):
'module': 'module':
model_state_dict, model_state_dict,
'lr_scheduler': 'lr_scheduler':
self.lr_scheduler.state_dict() self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
if self.lr_scheduler is not None else None,
'data_sampler': 'data_sampler':
self.training_dataloader.data_sampler.state_dict() if self.training_dataloader.data_sampler.state_dict() if
(self.training_dataloader is not None (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,
and self.curriculum_learning_enabled()) else None,
'random_ltd': 'random_ltd':
self.random_ltd_scheduler.state_dict() self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,
if self.random_ltd_enabled() else None,
'sparse_tensor_module_names': 'sparse_tensor_module_names':
self.sparse_tensor_module_names, self.sparse_tensor_module_names,
'skipped_steps': 'skipped_steps':
...@@ -3264,11 +3026,11 @@ class DeepSpeedEngine(Module): ...@@ -3264,11 +3026,11 @@ class DeepSpeedEngine(Module):
self._curr_save_path = None self._curr_save_path = None
def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
name_function = (self._get_zero_ckpt_name name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name)
if zero_checkpoint else self._get_ckpt_name)
try: try:
checkpoint_name = name_function(save_dir, tag) checkpoint_name = name_function(save_dir, tag)
ensure_directory_exists(checkpoint_name) path = os.path.dirname(checkpoint_name)
self.checkpoint_engine.makedirs(path, exist_ok=True)
except: except:
logger.error(f"Failed saving model checkpoint to {save_dir} with tag {tag}") logger.error(f"Failed saving model checkpoint to {save_dir} with tag {tag}")
return False return False
...@@ -3292,6 +3054,8 @@ class DeepSpeedEngine(Module): ...@@ -3292,6 +3054,8 @@ class DeepSpeedEngine(Module):
zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled() zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
save_frozen_param = self.zero_optimization_partition_gradients()
# A hack to save the checkpointing directory. Pipeline parallelism overrides # A hack to save the checkpointing directory. Pipeline parallelism overrides
# module_state_dict() and uses this path to save the model. module_state_dict() # module_state_dict() and uses this path to save the model. module_state_dict()
# then instead just returns None. The module_state_dict() implementation in # then instead just returns None. The module_state_dict() implementation in
...@@ -3302,17 +3066,17 @@ class DeepSpeedEngine(Module): ...@@ -3302,17 +3066,17 @@ class DeepSpeedEngine(Module):
state = dict(module=module, state = dict(module=module,
buffer_names=self._get_buffer_names(), buffer_names=self._get_buffer_names(),
optimizer=self.optimizer.state_dict() optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None,
if self.optimizer and not zero_optimizer_state else None, param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None,
param_shapes=self._get_zero_param_shapes() frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func)
if self.optimizer and zero_optimizer_state else None, if save_frozen_param else None,
lr_scheduler=self.lr_scheduler.state_dict() shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None,
if self.lr_scheduler is not None else None, frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func)
if save_frozen_param else None,
lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
data_sampler=self.training_dataloader.data_sampler.state_dict() if data_sampler=self.training_dataloader.data_sampler.state_dict() if
(self.training_dataloader is not None (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,
and self.curriculum_learning_enabled()) else None, random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,
random_ltd=self.random_ltd_scheduler.state_dict()
if self.random_ltd_enabled() else None,
sparse_tensor_module_names=self.sparse_tensor_module_names, sparse_tensor_module_names=self.sparse_tensor_module_names,
skipped_steps=self.skipped_steps, skipped_steps=self.skipped_steps,
global_steps=self.global_steps, global_steps=self.global_steps,
...@@ -3348,6 +3112,25 @@ class DeepSpeedEngine(Module): ...@@ -3348,6 +3112,25 @@ class DeepSpeedEngine(Module):
return buffer_names return buffer_names
def _get_param_shape_func(self, param):
return param.ds_shape if hasattr(param, 'ds_id') else param.shape
def _get_param_fragment_func(self, param):
return param.ds_tensor.detach().cpu() if hasattr(param, 'ds_id') else param.detach().cpu()
def _get_zero_frozen_param_attributes(self, attr_func):
frozen_param_fragments = OrderedDict()
for param in self.module.parameters():
if param.requires_grad:
continue
if param not in self.param_names:
raise ValueError(f"failed to find frozen {param} in named params")
name = self.param_names[param]
frozen_param_fragments[name] = attr_func(param)
return frozen_param_fragments
def _get_zero_param_shapes(self): def _get_zero_param_shapes(self):
"""Returns a dict of name to shape mapping, only for the flattened fp32 weights saved by the """Returns a dict of name to shape mapping, only for the flattened fp32 weights saved by the
optimizer. the names are exactly as in state_dict. The order is absolutely important, since optimizer. the names are exactly as in state_dict. The order is absolutely important, since
...@@ -3390,6 +3173,40 @@ class DeepSpeedEngine(Module): ...@@ -3390,6 +3173,40 @@ class DeepSpeedEngine(Module):
return param_group_shapes return param_group_shapes
def _get_shared_params(self):
"""
Returns a dict of shared params, which can later be used to reconstruct the original state dict,
e.g. in `zero_to_fp32`. Each dict entry is a pair of param names, where the key is the name
of the variable that isn't stored and the value is the actual param holding data.
"""
shared_ds_ids = {}
shared_params_by_full_name = {}
def get_layer_state_dict(module, prefix=""):
# handle params
for name, param in module.named_parameters(recurse=False):
if param is None or not hasattr(param, "ds_id"):
continue
key = prefix + name
# can't rely on param.data_ptr() as it will be reused as weights gets
# gathered and reduced, but param.ds_id is unique across all zero weights
# (and shared params will have the same param.ds_id)
if param.ds_id in shared_ds_ids:
# shared weights
#print(f"`{key}` is shared with `{shared_ds_ids[param.ds_id]}`")
shared_params_by_full_name[key] = shared_ds_ids[param.ds_id]
else:
shared_ds_ids[param.ds_id] = key
for name, child in module.named_children():
if child is not None:
get_layer_state_dict(child, prefix + name + ".")
if dist.get_rank() == 0:
get_layer_state_dict(self.module, prefix="")
return shared_params_by_full_name
def _copy_recovery_script(self, save_path): def _copy_recovery_script(self, save_path):
base_dir = os.path.dirname(os.path.dirname(__file__)) base_dir = os.path.dirname(os.path.dirname(__file__))
script = "zero_to_fp32.py" script = "zero_to_fp32.py"
...@@ -3402,9 +3219,7 @@ class DeepSpeedEngine(Module): ...@@ -3402,9 +3219,7 @@ class DeepSpeedEngine(Module):
def _save_zero_checkpoint(self, save_path, tag): def _save_zero_checkpoint(self, save_path, tag):
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag) zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), ds_config=self.config, ds_version=version)
ds_config=self.config,
ds_version=version)
self.checkpoint_engine.save(zero_sd, zero_checkpoint_name) self.checkpoint_engine.save(zero_sd, zero_checkpoint_name)
if self.global_rank == 0: if self.global_rank == 0:
...@@ -3434,9 +3249,7 @@ class DeepSpeedEngine(Module): ...@@ -3434,9 +3249,7 @@ class DeepSpeedEngine(Module):
# gather one layer at a time to be memory-efficient # gather one layer at a time to be memory-efficient
# must use modifier_rank=0 to release GPU memory after each layer gathered # must use modifier_rank=0 to release GPU memory after each layer gathered
#see_memory_usage("before GatheredParameters", force=True) #see_memory_usage("before GatheredParameters", force=True)
with deepspeed.zero.GatheredParameters(list( with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
module.parameters(recurse=False)),
modifier_rank=0):
if dist.get_rank() == 0: if dist.get_rank() == 0:
# handle params # handle params
for name, param in module.named_parameters(recurse=False): for name, param in module.named_parameters(recurse=False):
...@@ -3457,8 +3270,7 @@ class DeepSpeedEngine(Module): ...@@ -3457,8 +3270,7 @@ class DeepSpeedEngine(Module):
# now buffers - not sure if need to take care of potentially shared weights here # now buffers - not sure if need to take care of potentially shared weights here
for name, buf in module.named_buffers(recurse=False): for name, buf in module.named_buffers(recurse=False):
if (buf is not None if (buf is not None and name not in module._non_persistent_buffers_set):
and name not in module._non_persistent_buffers_set):
state_dict[prefix + name] = buf.detach().cpu() state_dict[prefix + name] = buf.detach().cpu()
#see_memory_usage("after GatheredParameters", force=True) #see_memory_usage("after GatheredParameters", force=True)
...@@ -3511,15 +3323,29 @@ class DeepSpeedEngine(Module): ...@@ -3511,15 +3323,29 @@ class DeepSpeedEngine(Module):
else: else:
# the model will be bogus if not consolidated so don't confuse the user by saving it # the model will be bogus if not consolidated so don't confuse the user by saving it
logger.info( logger.info(
f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False" f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False")
)
return False return False
else: else:
state_dict = self.module.state_dict() state_dict = self.module.state_dict()
tag = f"global_step{self.global_steps}"
tag = str(tag)
self.checkpoint_engine.create(tag)
if dist.get_rank() == 0: if dist.get_rank() == 0:
os.makedirs(save_dir, exist_ok=True) self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
logger.info(f"Saving model weights to {path}") logger.info(f"Saving model weights to {path}, tag: {tag}")
self.checkpoint_engine.save(state_dict, path) self.checkpoint_engine.save(state_dict, path)
self.checkpoint_engine.commit(tag)
return True return True
def empty_partition_cache(self):
"""
Release GPU memory consumed by offloaded model parameters.
"""
if hasattr(self.optimizer, 'empty_partition_cache'):
self.optimizer.empty_partition_cache()
gc.collect()
get_accelerator().empty_cache()
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team''' '''Copyright The Microsoft DeepSpeed Team'''
''' # Copyright (c) Microsoft Corporation.
Copyright 2019 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Copyright NVIDIA/apex Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex This file is adapted from FP16_Optimizer in NVIDIA/apex
''' """
import torch import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
...@@ -23,6 +25,7 @@ class FP16_Optimizer(DeepSpeedOptimizer): ...@@ -23,6 +25,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
For usage example please see, TODO: DeepSpeed V2 Tutorial For usage example please see, TODO: DeepSpeed V2 Tutorial
""" """
def __init__(self, def __init__(self,
init_optimizer, init_optimizer,
deepspeed=None, deepspeed=None,
...@@ -58,20 +61,15 @@ class FP16_Optimizer(DeepSpeedOptimizer): ...@@ -58,20 +61,15 @@ class FP16_Optimizer(DeepSpeedOptimizer):
# push this group to list before modify # push this group to list before modify
self.fp16_groups.append(param_group['params']) self.fp16_groups.append(param_group['params'])
# init fp16 weight buffer, flattened # init fp16 weight buffer, flattened
self.fp16_groups_flat.append( self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
_flatten_dense_tensors([p.clone().detach()
for p in self.fp16_groups[i]]))
# set model fp16 weight to slices of flattened buffer # set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params): for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data p.data = q.data
# init master weight, flattened # init master weight, flattened
self.fp32_groups_flat.append( self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
self.fp16_groups_flat[i].clone().float().detach())
# modify optimizer of have flat master weight # modify optimizer of have flat master weight
self.fp32_groups_flat[ self.fp32_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it
i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.fp32_groups_flat[i]] param_group['params'] = [self.fp32_groups_flat[i]]
# we may have a way of fusing dynamic scale. Do not support for now # we may have a way of fusing dynamic scale. Do not support for now
...@@ -113,16 +111,13 @@ class FP16_Optimizer(DeepSpeedOptimizer): ...@@ -113,16 +111,13 @@ class FP16_Optimizer(DeepSpeedOptimizer):
self.mpu = mpu self.mpu = mpu
self.overflow = False self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu, deepspeed=deepspeed)
mpu=self.mpu,
deepspeed=deepspeed)
self.initialize_optimizer_states() self.initialize_optimizer_states()
def initialize_optimizer_states(self): def initialize_optimizer_states(self):
for i, group in enumerate(self.fp16_groups): for i, group in enumerate(self.fp16_groups):
self.fp32_groups_flat[i].grad = torch.zeros( self.fp32_groups_flat[i].grad = torch.zeros(self.fp32_groups_flat[i].size(),
self.fp32_groups_flat[i].size(), device=self.fp32_groups_flat[i].device)
device=self.fp32_groups_flat[i].device)
self.optimizer.step() self.optimizer.step()
...@@ -156,10 +151,7 @@ class FP16_Optimizer(DeepSpeedOptimizer): ...@@ -156,10 +151,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
for i, group in enumerate(self.fp16_groups): for i, group in enumerate(self.fp16_groups):
grads_groups_flat.append( grads_groups_flat.append(
_flatten_dense_tensors([ _flatten_dense_tensors([
torch.zeros(p.size(), torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad for p in group
dtype=p.dtype,
device=p.device) if p.grad is None else p.grad
for p in group
])) ]))
norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu)) norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))
...@@ -169,17 +161,13 @@ class FP16_Optimizer(DeepSpeedOptimizer): ...@@ -169,17 +161,13 @@ class FP16_Optimizer(DeepSpeedOptimizer):
if self.overflow: if self.overflow:
if self.verbose: if self.verbose:
logger.info( logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
"[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss " "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
return self.overflow return self.overflow
scaled_grad_norm = get_global_norm(norm_list=norm_groups) scaled_grad_norm = get_global_norm(norm_list=norm_groups)
combined_scale = self.unscale_and_clip_grads(grads_groups_flat, combined_scale = self.unscale_and_clip_grads(grads_groups_flat, scaled_grad_norm, apply_scale=False)
scaled_grad_norm,
apply_scale=False)
# Stash unscaled gradient norm # Stash unscaled gradient norm
self._global_grad_norm = scaled_grad_norm / self.cur_scale self._global_grad_norm = scaled_grad_norm / self.cur_scale
...@@ -191,8 +179,7 @@ class FP16_Optimizer(DeepSpeedOptimizer): ...@@ -191,8 +179,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
grad_norms=norm_groups) grad_norms=norm_groups)
# TODO: we probably don't need this? just to be safe # TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)): for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params): for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data p.data = q.data
return self.overflow return self.overflow
...@@ -222,9 +209,7 @@ class FP16_Optimizer(DeepSpeedOptimizer): ...@@ -222,9 +209,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
def override_loss_scale(self, loss_scale): def override_loss_scale(self, loss_scale):
if loss_scale != self.external_loss_scale: if loss_scale != self.external_loss_scale:
logger.info( logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}'
)
self.custom_loss_scaler = True self.custom_loss_scaler = True
self.external_loss_scale = loss_scale self.external_loss_scale = loss_scale
...@@ -273,10 +258,8 @@ class FP16_Optimizer(DeepSpeedOptimizer): ...@@ -273,10 +258,8 @@ class FP16_Optimizer(DeepSpeedOptimizer):
grads_groups_flat.append( grads_groups_flat.append(
_flatten_dense_tensors([ _flatten_dense_tensors([
torch.zeros(p.size(), torch.zeros(p.size(), dtype=data_type, device=p.device) if p.grad is None else p.grad.to(data_type)
dtype=data_type, for p in group
device=p.device)
if p.grad is None else p.grad.to(data_type) for p in group
])) ]))
for p in group: for p in group:
...@@ -313,8 +296,7 @@ class FP16_Optimizer(DeepSpeedOptimizer): ...@@ -313,8 +296,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
self.start_timers([UPDATE_FP16]) self.start_timers([UPDATE_FP16])
for i in range(len(self.fp16_groups)): for i in range(len(self.fp16_groups)):
updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i])
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params): for p, q in zip(self.fp16_groups[i], updated_params):
p.data.copy_(q.data) p.data.copy_(q.data)
...@@ -334,9 +316,7 @@ class FP16_Optimizer(DeepSpeedOptimizer): ...@@ -334,9 +316,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
else: else:
pg = groups._get_data_parallel_group() pg = groups._get_data_parallel_group()
scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg)) scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
scaled_norm_tensor = torch.tensor(scaled_norm, scaled_norm_tensor = torch.tensor(scaled_norm, device=self.fp32_groups_flat[0].device, dtype=torch.float)
device=self.fp32_groups_flat[0].device,
dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=pg) dist.all_reduce(scaled_norm_tensor, group=pg)
all_groups_norm = scaled_norm_tensor.item() all_groups_norm = scaled_norm_tensor.item()
#print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}") #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
...@@ -376,25 +356,19 @@ class FP16_Optimizer(DeepSpeedOptimizer): ...@@ -376,25 +356,19 @@ class FP16_Optimizer(DeepSpeedOptimizer):
if self.dynamic_loss_scale: if self.dynamic_loss_scale:
prev_scale = self.cur_scale prev_scale = self.cur_scale
if skip: if skip:
self.cur_scale = max(self.cur_scale / self.scale_factor, self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
self.min_loss_scale)
self.last_overflow_iter = self.cur_iter self.last_overflow_iter = self.cur_iter
if self.verbose: if self.verbose:
logger.info(f"\nGrad overflow on iteration {self.cur_iter}") logger.info(f"\nGrad overflow on iteration {self.cur_iter}")
logger.info( logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}")
f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}"
)
else: else:
# Ensure self.scale_window updates since last overflow # Ensure self.scale_window updates since last overflow
stable_interval = (self.cur_iter - self.last_overflow_iter) - 1 stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
if (stable_interval > 0) and (stable_interval % self.scale_window == 0): if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
self.cur_scale *= self.scale_factor self.cur_scale *= self.scale_factor
if self.verbose: if self.verbose:
logger.info( logger.info(f"No Grad overflow for {self.scale_window} iterations")
f"No Grad overflow for {self.scale_window} iterations") logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}")
logger.info(
f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}"
)
else: else:
if skip: if skip:
logger.info("Grad overflow on iteration: %s", self.cur_iter) logger.info("Grad overflow on iteration: %s", self.cur_iter)
......
# Copyright 2019 The Microsoft DeepSpeed Team # Copyright (c) Microsoft Corporation.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); # DeepSpeed Team
# you may not use this file except in compliance with the License. """
# You may obtain a copy of the License at Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# http://www.apache.org/licenses/LICENSE-2.0 Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# Unless required by applicable law or agreed to in writing, software You may obtain a copy of the License at
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. http://www.apache.org/licenses/LICENSE-2.0
# See the License for the specific language governing permissions and
# limitations under the License. Unless required by applicable law or agreed to in writing, software
#Taken and modified for DeepSpeed from: distributed under the License is distributed on an "AS IS" BASIS,
# https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9 See the License for the specific language governing permissions and
limitations under the License.
Taken and modified for DeepSpeed from:
https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
"""
import torch import torch
from deepspeed import comm as dist
from deepspeed.utils import logger
INITIAL_LOSS_SCALE = 'init_scale' INITIAL_LOSS_SCALE = 'init_scale'
SCALE_WINDOW = 'scale_window' SCALE_WINDOW = 'scale_window'
...@@ -35,6 +42,7 @@ class LossScalerBase: ...@@ -35,6 +42,7 @@ class LossScalerBase:
"""LossScalarBase """LossScalarBase
Base class for a loss scaler Base class for a loss scaler
""" """
def __init__(self, cur_scale): def __init__(self, cur_scale):
self.cur_scale = cur_scale self.cur_scale = cur_scale
self.dynamic = False self.dynamic = False
...@@ -52,6 +60,7 @@ class LossScalerBase: ...@@ -52,6 +60,7 @@ class LossScalerBase:
def backward(self, loss, retain_graph=False): def backward(self, loss, retain_graph=False):
scaled_loss = loss * self.loss_scale scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph) scaled_loss.backward(retain_graph=retain_graph)
# print(f'LossScalerBackward: {scaled_loss=}')
class LossScaler(LossScalerBase): class LossScaler(LossScalerBase):
...@@ -65,6 +74,7 @@ class LossScaler(LossScalerBase): ...@@ -65,6 +74,7 @@ class LossScaler(LossScalerBase):
Args: Args:
scale (float, optional, default=1.0): The loss scale. scale (float, optional, default=1.0): The loss scale.
""" """
def __init__(self, scale=1): def __init__(self, scale=1):
super(LossScaler, self).__init__(scale) super(LossScaler, self).__init__(scale)
...@@ -102,6 +112,7 @@ class DynamicLossScaler(LossScalerBase): ...@@ -102,6 +112,7 @@ class DynamicLossScaler(LossScalerBase):
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
""" """
def __init__(self, def __init__(self,
init_scale=2**32, init_scale=2**32,
scale_factor=2., scale_factor=2.,
...@@ -109,7 +120,8 @@ class DynamicLossScaler(LossScalerBase): ...@@ -109,7 +120,8 @@ class DynamicLossScaler(LossScalerBase):
min_scale=1, min_scale=1,
delayed_shift=1, delayed_shift=1,
consecutive_hysteresis=False, consecutive_hysteresis=False,
raise_error_at_min_scale=True): raise_error_at_min_scale=True,
dtype=torch.half):
super(DynamicLossScaler, self).__init__(init_scale) super(DynamicLossScaler, self).__init__(init_scale)
self.cur_iter = 0 self.cur_iter = 0
self.last_overflow_iter = -1 self.last_overflow_iter = -1
...@@ -121,6 +133,7 @@ class DynamicLossScaler(LossScalerBase): ...@@ -121,6 +133,7 @@ class DynamicLossScaler(LossScalerBase):
self.consecutive_hysteresis = consecutive_hysteresis self.consecutive_hysteresis = consecutive_hysteresis
self.raise_error_at_min_scale = raise_error_at_min_scale self.raise_error_at_min_scale = raise_error_at_min_scale
self.dynamic = True self.dynamic = True
self.dtype = dtype
# `params` is a list / generator of torch.Variable # `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params): def has_overflow_serial(self, params):
...@@ -158,10 +171,21 @@ class DynamicLossScaler(LossScalerBase): ...@@ -158,10 +171,21 @@ class DynamicLossScaler(LossScalerBase):
if self.delayed_shift == 1 or self.cur_hysteresis == 1: if self.delayed_shift == 1 or self.cur_hysteresis == 1:
if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale: if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale:
raise Exception( raise Exception(
"Current loss scale already at minimum - cannot decrease scale anymore. Exiting run." "Current loss scale already at minimum - cannot decrease scale anymore. Exiting run.")
) else:
self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale) next_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
if dist.get_rank() == 0:
overflow_msg = f"[deepspeed] OVERFLOW! Rank {dist.get_rank()} Skipping step."
if self.dtype == torch.half:
overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, reducing to {int(next_scale)}"
logger.info(overflow_msg)
self.cur_scale = next_scale
else: else:
if dist.get_rank() == 0:
overflow_msg = f"[deepspeed] OVERFLOW! Rank {dist.get_rank()} Skipping step."
if self.dtype == torch.half:
overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, but hysteresis is {self.cur_hysteresis}. Reducing hysteresis to {self.cur_hysteresis-1}"
logger.info(overflow_msg)
self.cur_hysteresis -= 1 self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter self.last_overflow_iter = self.cur_iter
else: else:
...@@ -179,8 +203,8 @@ class DynamicLossScaler(LossScalerBase): ...@@ -179,8 +203,8 @@ class DynamicLossScaler(LossScalerBase):
def CreateLossScaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args): def CreateLossScaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args):
if dtype == torch.half and dynamic_scaling: if dtype == torch.half and dynamic_scaling:
if dynamic_loss_args is None: if dynamic_loss_args is None:
return DynamicLossScaler() return DynamicLossScaler(dtype=dtype)
return DynamicLossScaler(**dynamic_loss_args) return DynamicLossScaler(dtype=dtype, **dynamic_loss_args)
loss_scale_value = static_loss_scale if dtype == torch.half else 1.0 loss_scale_value = static_loss_scale if dtype == torch.half else 1.0
return LossScaler(scale=loss_scale_value) return LossScaler(scale=loss_scale_value)
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .adam import OnebitAdam from .adam import OnebitAdam
from .lamb import OnebitLamb from .lamb import OnebitLamb
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import types import types
import torch import torch
import numpy as np import numpy as np
...@@ -39,14 +41,14 @@ class OnebitAdam(torch.optim.Optimizer): ...@@ -39,14 +41,14 @@ class OnebitAdam(torch.optim.Optimizer):
.. _On the Convergence of Adam and Beyond: .. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ https://openreview.net/forum?id=ryQu7f-RZ
""" """
def __init__(self, def __init__(self,
params, params,
deepspeed=None, deepspeed=None,
lr=1e-3, lr=1e-3,
freeze_step=100000, freeze_step=100000,
bias_correction=True, bias_correction=True,
betas=(0.9, betas=(0.9, 0.999),
0.999),
eps=1e-8, eps=1e-8,
eps_inside_sqrt=False, eps_inside_sqrt=False,
weight_decay=0., weight_decay=0.,
...@@ -89,11 +91,12 @@ class OnebitAdam(torch.optim.Optimizer): ...@@ -89,11 +91,12 @@ class OnebitAdam(torch.optim.Optimizer):
if self.comm_backend_name == 'nccl': if self.comm_backend_name == 'nccl':
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1]) TORCH_MINOR = int(torch.__version__.split('.')[1])
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" assert (
(TORCH_MAJOR == 1 and TORCH_MINOR >= 8) or TORCH_MAJOR >= 2
), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend." assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend from deepspeed.runtime.comm.nccl import NcclBackend
self.using_pipeline = hasattr(self.deepspeed, self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
'pipeline_enable_backward_allreduce')
self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)
elif self.comm_backend_name == 'mpi': elif self.comm_backend_name == 'mpi':
...@@ -164,22 +167,17 @@ class OnebitAdam(torch.optim.Optimizer): ...@@ -164,22 +167,17 @@ class OnebitAdam(torch.optim.Optimizer):
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data) state['exp_avg_sq'] = torch.zeros_like(p.data)
if not self.initialize or (self.adam_freeze_key if not self.initialize or (self.adam_freeze_key and 'worker_error' not in state.keys()):
and 'worker_error' not in state.keys()):
state['tensor_size'] = torch.numel(p.data) state['tensor_size'] = torch.numel(p.data)
state['corrected_tensor_size'] = state['tensor_size'] state['corrected_tensor_size'] = state['tensor_size']
if state['tensor_size'] % (self.size * self.divider) != 0: if state['tensor_size'] % (self.size * self.divider) != 0:
state['corrected_tensor_size'] += ((self.size * self.divider) - state['corrected_tensor_size'] += ((self.size * self.divider) - (state['tensor_size'] %
(state['tensor_size'] % (self.size * self.divider)))
(self.size * self.divider))) state['server_chunk_size'] = state['corrected_tensor_size'] // self.size
state['server_chunk_size'] = state[
'corrected_tensor_size'] // self.size
get_accelerator().empty_cache() get_accelerator().empty_cache()
state['worker_error'] = torch.zeros(state['corrected_tensor_size'], state['worker_error'] = torch.zeros(state['corrected_tensor_size'], device=p.device)
device=p.device) state['server_error'] = torch.zeros(state['server_chunk_size'], device=p.device)
state['server_error'] = torch.zeros(state['server_chunk_size'],
device=p.device)
get_accelerator().empty_cache() get_accelerator().empty_cache()
self.adam_freeze_key = True self.adam_freeze_key = True
if not self.initialize and dist.get_rank() == 0: if not self.initialize and dist.get_rank() == 0:
...@@ -211,11 +209,9 @@ class OnebitAdam(torch.optim.Optimizer): ...@@ -211,11 +209,9 @@ class OnebitAdam(torch.optim.Optimizer):
if self.size > 1: if self.size > 1:
exp_avg.set_( exp_avg.set_(
self.comm_backend_handle.compressed_allreduce( self.comm_backend_handle.compressed_allreduce(exp_avg, state['worker_error'],
exp_avg, state['server_error'],
state['worker_error'], self.deepspeed.local_rank))
state['server_error'],
self.deepspeed.local_rank))
# Because 1-bit compression cannot represent exact zero, it is required to # Because 1-bit compression cannot represent exact zero, it is required to
# provide a momentum mask for those params that have constant exact zeros in their # provide a momentum mask for those params that have constant exact zeros in their
# momentums, otherwise the compression error would keep accumulating. # momentums, otherwise the compression error would keep accumulating.
...@@ -225,8 +221,7 @@ class OnebitAdam(torch.optim.Optimizer): ...@@ -225,8 +221,7 @@ class OnebitAdam(torch.optim.Optimizer):
# (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.) # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
if 'exp_avg_mask' in group: if 'exp_avg_mask' in group:
if exp_avg.device != group['exp_avg_mask'].device: if exp_avg.device != group['exp_avg_mask'].device:
group['exp_avg_mask'] = group['exp_avg_mask'].to( group['exp_avg_mask'] = group['exp_avg_mask'].to(device=exp_avg.device)
device=exp_avg.device)
exp_avg.mul_(group['exp_avg_mask']) exp_avg.mul_(group['exp_avg_mask'])
if self.initialize: if self.initialize:
...@@ -272,8 +267,7 @@ class OnebitAdam(torch.optim.Optimizer): ...@@ -272,8 +267,7 @@ class OnebitAdam(torch.optim.Optimizer):
for i, group in enumerate(self.param_groups): for i, group in enumerate(self.param_groups):
if 'exp_avg_mask' in group: if 'exp_avg_mask' in group:
state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask'] state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask']
elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict[ elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict['param_groups'][i]:
'param_groups'][i]:
state_dict['param_groups'][i].pop('exp_avg_mask') state_dict['param_groups'][i].pop('exp_avg_mask')
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step: if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
...@@ -287,9 +281,7 @@ class OnebitAdam(torch.optim.Optimizer): ...@@ -287,9 +281,7 @@ class OnebitAdam(torch.optim.Optimizer):
self.deepspeed.enable_backward_allreduce = True self.deepspeed.enable_backward_allreduce = True
else: else:
if dist.get_rank() == 0: if dist.get_rank() == 0:
print( print("Checkpoint loaded and OnebitAdam compression stage starts/continues.")
"Checkpoint loaded and OnebitAdam compression stage starts/continues."
)
if self.adam_freeze_key is False: if self.adam_freeze_key is False:
self.adam_freeze_key = True self.adam_freeze_key = True
if self.using_pipeline: if self.using_pipeline:
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2021 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import types import types
import torch import torch
import numpy as np import numpy as np
...@@ -54,14 +56,14 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -54,14 +56,14 @@ class OnebitLamb(torch.optim.Optimizer):
.. _On the Convergence of Adam and Beyond: .. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ https://openreview.net/forum?id=ryQu7f-RZ
""" """
def __init__(self, def __init__(self,
params, params,
deepspeed=None, deepspeed=None,
lr=1e-3, lr=1e-3,
freeze_step=100000, freeze_step=100000,
bias_correction=True, bias_correction=True,
betas=(0.9, betas=(0.9, 0.999),
0.999),
eps=1e-8, eps=1e-8,
eps_inside_sqrt=False, eps_inside_sqrt=False,
weight_decay=0., weight_decay=0.,
...@@ -111,11 +113,12 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -111,11 +113,12 @@ class OnebitLamb(torch.optim.Optimizer):
if self.comm_backend_name == 'nccl': if self.comm_backend_name == 'nccl':
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1]) TORCH_MINOR = int(torch.__version__.split('.')[1])
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" assert (
(TORCH_MAJOR == 1 and TORCH_MINOR >= 8) or TORCH_MAJOR >= 2
), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend." assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend from deepspeed.runtime.comm.nccl import NcclBackend
self.using_pipeline = hasattr(self.deepspeed, self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
'pipeline_enable_backward_allreduce')
self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)
elif self.comm_backend_name == 'mpi': elif self.comm_backend_name == 'mpi':
...@@ -165,24 +168,20 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -165,24 +168,20 @@ class OnebitLamb(torch.optim.Optimizer):
if self.lamb_freeze_key: if self.lamb_freeze_key:
exp_avg_last_step = [] exp_avg_last_step = []
for group in self.param_groups: for group in self.param_groups:
exp_avg_last_step.append( exp_avg_last_step.append([self.state[p]['exp_avg'].detach().clone() for p in group['params']])
[self.state[p]['exp_avg'].detach().clone() for p in group['params']])
if 'scaling_coeff' not in self.state[self.param_groups[0]['params'][0]]: if 'scaling_coeff' not in self.state[self.param_groups[0]['params'][0]]:
# Compute the scaling_coeff for each momentum at the end of warmup stage. # Compute the scaling_coeff for each momentum at the end of warmup stage.
# This is used to reduce compression error during compression stage. # This is used to reduce compression error during compression stage.
momentum_scales = [] momentum_scales = []
for group in self.param_groups: for group in self.param_groups:
momentum_scales.append([ momentum_scales.append([
(torch.norm(self.state[p]['exp_avg']) / (torch.norm(self.state[p]['exp_avg']) / np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
for p in group['params'] for p in group['params']
]) ])
united_scale = sum([sum(x) for x in momentum_scales]) / sum( united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])
[len(x) for x in momentum_scales])
for i, group in enumerate(self.param_groups): for i, group in enumerate(self.param_groups):
for j, p in enumerate(group['params']): for j, p in enumerate(group['params']):
self.state[p][ self.state[p]['scaling_coeff'] = united_scale / momentum_scales[i][j]
'scaling_coeff'] = united_scale / momentum_scales[i][j]
for group, grads_this_group in zip(self.param_groups, grads_group): for group, grads_this_group in zip(self.param_groups, grads_group):
if grads_this_group is None: if grads_this_group is None:
...@@ -201,8 +200,7 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -201,8 +200,7 @@ class OnebitLamb(torch.optim.Optimizer):
state = self.state[p] state = self.state[p]
# State initialization # State initialization
if len(state) == 0 or (len(state) == 1 if len(state) == 0 or (len(state) == 1 and 'scaling_coeff' in state.keys()):
and 'scaling_coeff' in state.keys()):
state['step'] = 0 state['step'] = 0
state['lamb_coeff_freeze'] = 0.0 state['lamb_coeff_freeze'] = 0.0
state['last_factor'] = 1.0 state['last_factor'] = 1.0
...@@ -215,7 +213,8 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -215,7 +213,8 @@ class OnebitLamb(torch.optim.Optimizer):
if not self.initialize: if not self.initialize:
self.lamb_freeze_key = True self.lamb_freeze_key = True
exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_sq_fresh'] exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state[
'exp_avg_sq_fresh']
beta1, beta2 = group['betas'] beta1, beta2 = group['betas']
max_coeff = group['max_coeff'] max_coeff = group['max_coeff']
min_coeff = group['min_coeff'] min_coeff = group['min_coeff']
...@@ -243,8 +242,8 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -243,8 +242,8 @@ class OnebitLamb(torch.optim.Optimizer):
if lamb_coeff < min_coeff: if lamb_coeff < min_coeff:
lamb_coeff = min_coeff lamb_coeff = min_coeff
if lamb_coeff != 1.0: if lamb_coeff != 1.0:
state['lamb_coeff_freeze'] = self.coeff_beta * state[ state['lamb_coeff_freeze'] = self.coeff_beta * state['lamb_coeff_freeze'] + (
'lamb_coeff_freeze'] + (1 - self.coeff_beta) * lamb_coeff 1 - self.coeff_beta) * lamb_coeff
self.lamb_coeffs.append(lamb_coeff) self.lamb_coeffs.append(lamb_coeff)
with torch.no_grad(): with torch.no_grad():
p.add_(-group['lr'] * lamb_coeff * update) p.add_(-group['lr'] * lamb_coeff * update)
...@@ -266,20 +265,15 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -266,20 +265,15 @@ class OnebitLamb(torch.optim.Optimizer):
tensor_size += torch.numel(p.data) tensor_size += torch.numel(p.data)
corrected_tensor_size = tensor_size corrected_tensor_size = tensor_size
if tensor_size % (self.size * self.divider) != 0: if tensor_size % (self.size * self.divider) != 0:
difference = ((self.size * self.divider) - (tensor_size % difference = ((self.size * self.divider) - (tensor_size % (self.size * self.divider)))
(self.size * self.divider)))
corrected_tensor_size += difference corrected_tensor_size += difference
self.dummy_exp_avg[0] = torch.zeros( self.dummy_exp_avg[0] = torch.zeros(difference, device=momentum_groups[0].data.device)
difference,
device=momentum_groups[0].data.device)
momentum_groups.append(self.dummy_exp_avg[0]) momentum_groups.append(self.dummy_exp_avg[0])
self.corrected_tensor_sizes.append(corrected_tensor_size) self.corrected_tensor_sizes.append(corrected_tensor_size)
self.server_chunk_sizes.append(corrected_tensor_size // self.size) self.server_chunk_sizes.append(corrected_tensor_size // self.size)
self.exp_avg_flat.append( self.exp_avg_flat.append(_flatten_dense_tensors([p.detach().clone() for p in momentum_groups]))
_flatten_dense_tensors([p.detach().clone() for p in momentum_groups])) updated_params = _unflatten_dense_tensors(self.exp_avg_flat[0], momentum_groups)
updated_params = _unflatten_dense_tensors(self.exp_avg_flat[0],
momentum_groups)
for p, q in zip(momentum_groups, updated_params): for p, q in zip(momentum_groups, updated_params):
p.data = q.data p.data = q.data
...@@ -287,11 +281,8 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -287,11 +281,8 @@ class OnebitLamb(torch.optim.Optimizer):
get_accelerator().empty_cache() get_accelerator().empty_cache()
for i in range(len(self.exp_avg_flat)): for i in range(len(self.exp_avg_flat)):
self.worker_errors.append( self.worker_errors.append(
torch.zeros(self.corrected_tensor_sizes[i], torch.zeros(self.corrected_tensor_sizes[i], device=self.exp_avg_flat[i].device))
device=self.exp_avg_flat[i].device)) self.server_errors.append(torch.zeros(self.server_chunk_sizes[i], device=self.exp_avg_flat[i].device))
self.server_errors.append(
torch.zeros(self.server_chunk_sizes[i],
device=self.exp_avg_flat[i].device))
get_accelerator().empty_cache() get_accelerator().empty_cache()
if self.lamb_freeze_key: if self.lamb_freeze_key:
...@@ -300,31 +291,23 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -300,31 +291,23 @@ class OnebitLamb(torch.optim.Optimizer):
if not self.initialize: if not self.initialize:
get_accelerator().empty_cache() get_accelerator().empty_cache()
self.worker_errors.append( self.worker_errors.append(
torch.zeros(self.corrected_tensor_sizes[i], torch.zeros(self.corrected_tensor_sizes[i], device=self.exp_avg_flat[i].device))
device=self.exp_avg_flat[i].device))
self.server_errors.append( self.server_errors.append(
torch.zeros(self.server_chunk_sizes[i], torch.zeros(self.server_chunk_sizes[i], device=self.exp_avg_flat[i].device))
device=self.exp_avg_flat[i].device))
get_accelerator().empty_cache() get_accelerator().empty_cache()
if dist.get_rank() == 0: if dist.get_rank() == 0:
print("Cupy Buffers Initialized Successfully.") print("Cupy Buffers Initialized Successfully.")
self.comm_backend_handle.compressed_allreduce( self.comm_backend_handle.compressed_allreduce(self.exp_avg_flat[i], self.worker_errors[0],
self.exp_avg_flat[i], self.server_errors[0], self.deepspeed.local_rank)
self.worker_errors[0],
self.server_errors[0],
self.deepspeed.local_rank)
if dist.get_rank() == 0: if dist.get_rank() == 0:
print('Pop out errors', flush=True) print('Pop out errors', flush=True)
del self.worker_errors[:] del self.worker_errors[:]
del self.server_errors[:] del self.server_errors[:]
else: else:
self.comm_backend_handle.compressed_allreduce( self.comm_backend_handle.compressed_allreduce(self.exp_avg_flat[i], self.worker_errors[i],
self.exp_avg_flat[i], self.server_errors[i], self.deepspeed.local_rank)
self.worker_errors[i],
self.server_errors[i],
self.deepspeed.local_rank)
if self.lamb_freeze_key and self.initialize: if self.lamb_freeze_key and self.initialize:
for i, group in enumerate(self.param_groups): for i, group in enumerate(self.param_groups):
...@@ -332,7 +315,8 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -332,7 +315,8 @@ class OnebitLamb(torch.optim.Optimizer):
for j, p in enumerate(group['params']): for j, p in enumerate(group['params']):
state = self.state[p] state = self.state[p]
exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_sq_fresh'] exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state[
'exp_avg_sq_fresh']
beta1, beta2 = group['betas'] beta1, beta2 = group['betas']
exp_avg.div_(self.state[p]['scaling_coeff']) exp_avg.div_(self.state[p]['scaling_coeff'])
# Because 1-bit compression cannot represent exact zero, it is required to # Because 1-bit compression cannot represent exact zero, it is required to
...@@ -345,15 +329,11 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -345,15 +329,11 @@ class OnebitLamb(torch.optim.Optimizer):
# to add this exp_avg_mask for BERT pre-training.) # to add this exp_avg_mask for BERT pre-training.)
if 'exp_avg_mask' in group: if 'exp_avg_mask' in group:
if exp_avg.device != group['exp_avg_mask'].device: if exp_avg.device != group['exp_avg_mask'].device:
group['exp_avg_mask'] = group['exp_avg_mask'].to( group['exp_avg_mask'] = group['exp_avg_mask'].to(device=exp_avg.device)
device=exp_avg.device)
exp_avg.mul_(group['exp_avg_mask']) exp_avg.mul_(group['exp_avg_mask'])
grad_reconstruct = ((exp_avg - exp_avg_last_step[i][j] * beta1) / grad_reconstruct = ((exp_avg - exp_avg_last_step[i][j] * beta1) / (1 - beta1))
(1 - beta1)) exp_avg_sq_fresh.mul_(beta2).addcmul_(1 - beta2, grad_reconstruct, grad_reconstruct)
exp_avg_sq_fresh.mul_(beta2).addcmul_(1 - beta2,
grad_reconstruct,
grad_reconstruct)
denom = exp_avg_sq.sqrt() + group['eps'] denom = exp_avg_sq.sqrt() + group['eps']
update_prelim = exp_avg / denom update_prelim = exp_avg / denom
...@@ -367,9 +347,7 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -367,9 +347,7 @@ class OnebitLamb(torch.optim.Optimizer):
denom_real = exp_avg_sq_fresh.sqrt() + group['eps'] denom_real = exp_avg_sq_fresh.sqrt() + group['eps']
factor = (denom / denom_real).max().item() factor = (denom / denom_real).max().item()
if group['weight_decay'] > 0.0: if group['weight_decay'] > 0.0:
update_ratio = min(1.0, update_ratio = min(1.0, (update_prelim.pow(2).sum().sqrt() / update_norm).item())
(update_prelim.pow(2).sum().sqrt() /
update_norm).item())
factor = factor * update_ratio + (1.0 - update_ratio) factor = factor * update_ratio + (1.0 - update_ratio)
if factor > self.factor_max: if factor > self.factor_max:
factor = self.factor_max factor = self.factor_max
...@@ -416,8 +394,7 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -416,8 +394,7 @@ class OnebitLamb(torch.optim.Optimizer):
for i, group in enumerate(self.param_groups): for i, group in enumerate(self.param_groups):
if 'exp_avg_mask' in group: if 'exp_avg_mask' in group:
state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask'] state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask']
elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict[ elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict['param_groups'][i]:
'param_groups'][i]:
state_dict['param_groups'][i].pop('exp_avg_mask') state_dict['param_groups'][i].pop('exp_avg_mask')
super().load_state_dict(state_dict) super().load_state_dict(state_dict)
# need to reset the fused momentum since loading states will break the linking # need to reset the fused momentum since loading states will break the linking
...@@ -442,9 +419,7 @@ class OnebitLamb(torch.optim.Optimizer): ...@@ -442,9 +419,7 @@ class OnebitLamb(torch.optim.Optimizer):
self.state[p].pop('scaling_coeff') self.state[p].pop('scaling_coeff')
else: else:
if dist.get_rank() == 0: if dist.get_rank() == 0:
print( print("Checkpoint loaded and OnebitLamb compression stage starts/continues.")
"Checkpoint loaded and OnebitLamb compression stage starts/continues."
)
if self.lamb_freeze_key is False: if self.lamb_freeze_key is False:
self.lamb_freeze_key = True self.lamb_freeze_key = True
if self.using_pipeline: if self.using_pipeline:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment