Commit c25a91b6 authored by aiss's avatar aiss
Browse files

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os import os
import sys import sys
...@@ -13,6 +16,7 @@ from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE ...@@ -13,6 +16,7 @@ from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE
class MultiNodeRunner(ABC): class MultiNodeRunner(ABC):
def __init__(self, args, world_info_base64): def __init__(self, args, world_info_base64):
self.args = args self.args = args
self.validate_args() self.validate_args()
...@@ -45,6 +49,7 @@ class MultiNodeRunner(ABC): ...@@ -45,6 +49,7 @@ class MultiNodeRunner(ABC):
class PDSHRunner(MultiNodeRunner): class PDSHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64): def __init__(self, args, world_info_base64):
super().__init__(args, world_info_base64) super().__init__(args, world_info_base64)
...@@ -56,9 +61,7 @@ class PDSHRunner(MultiNodeRunner): ...@@ -56,9 +61,7 @@ class PDSHRunner(MultiNodeRunner):
return "pdsh" return "pdsh"
def parse_user_args(self): def parse_user_args(self):
return list( return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))
map(lambda x: x if x.startswith("-") else f"'{x}'",
self.args.user_args))
def get_cmd(self, environment, active_resources): def get_cmd(self, environment, active_resources):
environment['PDSH_RCMD_TYPE'] = 'ssh' environment['PDSH_RCMD_TYPE'] = 'ssh'
...@@ -68,14 +71,8 @@ class PDSHRunner(MultiNodeRunner): ...@@ -68,14 +71,8 @@ class PDSHRunner(MultiNodeRunner):
# PDSH flags for max node fan out and specific hosts to launch on # PDSH flags for max node fan out and specific hosts to launch on
# See https://linux.die.net/man/1/pdsh for flag details # See https://linux.die.net/man/1/pdsh for flag details
pdsh_cmd_args = [ pdsh_cmd_args = ['pdsh', '-S', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers] + split(
'pdsh', self.args.launcher_args)
'-S',
'-f',
str(PDSH_MAX_FAN_OUT),
'-w',
active_workers
] + split(self.args.launcher_args)
exports = "" exports = ""
for key, val in self.exports.items(): for key, val in self.exports.items():
...@@ -84,15 +81,8 @@ class PDSHRunner(MultiNodeRunner): ...@@ -84,15 +81,8 @@ class PDSHRunner(MultiNodeRunner):
# https://linux.die.net/man/1/pdsh # https://linux.die.net/man/1/pdsh
# %n will be replaced by pdsh command # %n will be replaced by pdsh command
deepspeed_launch = [ deepspeed_launch = [
exports, exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "deepspeed.launcher.launch",
f"cd {os.path.abspath('.')};", f'--world_info={self.world_info_base64}', "--node_rank=%n", f"--master_addr={self.args.master_addr}",
sys.executable,
"-u",
"-m",
"deepspeed.launcher.launch",
f'--world_info={self.world_info_base64}',
"--node_rank=%n",
f"--master_addr={self.args.master_addr}",
f"--master_port={self.args.master_port}" f"--master_port={self.args.master_port}"
] ]
if self.args.no_python: if self.args.no_python:
...@@ -111,11 +101,11 @@ class PDSHRunner(MultiNodeRunner): ...@@ -111,11 +101,11 @@ class PDSHRunner(MultiNodeRunner):
cmd_to_search = [i + "\\" for i in deepspeed_launch[2:6]] cmd_to_search = [i + "\\" for i in deepspeed_launch[2:6]]
kill_command = pdsh_cmd_args + ["pkill -f ", " ".join(cmd_to_search)[:-2]] kill_command = pdsh_cmd_args + ["pkill -f ", " ".join(cmd_to_search)[:-2]]
return pdsh_cmd_args + deepspeed_launch + [self.user_script return pdsh_cmd_args + deepspeed_launch + [self.user_script] + self.user_arguments, kill_command
] + self.user_arguments, kill_command
class OpenMPIRunner(MultiNodeRunner): class OpenMPIRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool): def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64) super().__init__(args, world_info_base64)
self.resource_pool = resource_pool self.resource_pool = resource_pool
...@@ -133,11 +123,9 @@ class OpenMPIRunner(MultiNodeRunner): ...@@ -133,11 +123,9 @@ class OpenMPIRunner(MultiNodeRunner):
super().validate_args() super().validate_args()
#TODO: Allow for include/exclude at node-level but not gpu-level #TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "": if self.args.include != "" or self.args.exclude != "":
raise ValueError( raise ValueError(f"{self.name} backend does not support worker include/exclusion")
f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1: if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError( raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")
f"{self.name} backend does not support limiting num nodes/gpus")
def get_cmd(self, environment, active_resources): def get_cmd(self, environment, active_resources):
total_process_count = sum(self.resource_pool.values()) total_process_count = sum(self.resource_pool.values())
...@@ -166,11 +154,11 @@ class OpenMPIRunner(MultiNodeRunner): ...@@ -166,11 +154,11 @@ class OpenMPIRunner(MultiNodeRunner):
if self.args.module: if self.args.module:
python_exec.append("-m") python_exec.append("-m")
return mpirun_cmd + export_cmd + python_exec + [self.user_script return mpirun_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments
] + self.user_arguments
class MPICHRunner(MultiNodeRunner): class MPICHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool): def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64) super().__init__(args, world_info_base64)
self.resource_pool = resource_pool self.resource_pool = resource_pool
...@@ -187,17 +175,22 @@ class MPICHRunner(MultiNodeRunner): ...@@ -187,17 +175,22 @@ class MPICHRunner(MultiNodeRunner):
super().validate_args() super().validate_args()
#TODO: Allow for include/exclude at node-level but not gpu-level #TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "": if self.args.include != "" or self.args.exclude != "":
raise ValueError( raise ValueError(f"{self.name} backend does not support worker include/exclusion")
f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1: if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError( raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")
f"{self.name} backend does not support limiting num nodes/gpus")
def get_cmd(self, environment, active_resources): def get_cmd(self, environment, active_resources):
devices_per_node = self.resource_pool.values() devices_per_node = self.resource_pool.values()
total_process_count = sum(devices_per_node) total_process_count = sum(devices_per_node)
process_per_node = list(devices_per_node)[0] process_per_node = list(devices_per_node)[0]
hosts = ""
for i, host in enumerate(self.resource_pool.keys()):
if i == 0:
hosts = f"{host}"
else:
hosts += f",{host}"
mpirun_cmd = [ mpirun_cmd = [
'mpirun', 'mpirun',
...@@ -205,21 +198,24 @@ class MPICHRunner(MultiNodeRunner): ...@@ -205,21 +198,24 @@ class MPICHRunner(MultiNodeRunner):
f'{total_process_count}', f'{total_process_count}',
'-ppn', '-ppn',
f'{process_per_node}', f'{process_per_node}',
'-hosts',
f'{hosts}',
] + split(self.args.launcher_args) ] + split(self.args.launcher_args)
export_cmd = [] export_cmd = []
for k, v in self.exports.items(): for k, v in self.exports.items():
export_cmd += ['-x', "{}={}".format(k, v)] export_cmd += ['-genv', "{}={}".format(k, v)]
python_exec = [] python_exec = []
if not self.args.no_python: if not self.args.no_python:
python_exec = [sys.executable, "-u"] python_exec = [sys.executable, "-u"]
if self.args.module: if self.args.module:
python_exec.append("-m") python_exec.append("-m")
return mpirun_cmd + python_exec + [self.user_script] + self.user_arguments return mpirun_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments
class SlurmRunner(MultiNodeRunner): class SlurmRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool): def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64) super().__init__(args, world_info_base64)
self.resource_pool = resource_pool self.resource_pool = resource_pool
...@@ -232,7 +228,8 @@ class SlurmRunner(MultiNodeRunner): ...@@ -232,7 +228,8 @@ class SlurmRunner(MultiNodeRunner):
return 'slurm' return 'slurm'
def get_cmd(self, environment, active_resources): def get_cmd(self, environment, active_resources):
assert not getattr(self.args, 'detect_nvlink_pairs', False), "slurm backend does not support remapping visible devices" assert not getattr(self.args, 'detect_nvlink_pairs',
False), "slurm backend does not support remapping visible devices"
total_process_count = sum(self.resource_pool.values()) total_process_count = sum(self.resource_pool.values())
srun_cmd = [ srun_cmd = [
'srun', 'srun',
...@@ -261,12 +258,12 @@ class SlurmRunner(MultiNodeRunner): ...@@ -261,12 +258,12 @@ class SlurmRunner(MultiNodeRunner):
exports += f",{key}={val}" exports += f",{key}={val}"
python_exec = [sys.executable, "-u"] python_exec = [sys.executable, "-u"]
command = srun_cmd + [exports] + python_exec + [self.user_script command = srun_cmd + [exports] + python_exec + [self.user_script] + self.user_arguments
] + self.user_arguments
return command return command
class MVAPICHRunner(MultiNodeRunner): class MVAPICHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool): def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64) super().__init__(args, world_info_base64)
self.resource_pool = resource_pool self.resource_pool = resource_pool
...@@ -303,9 +300,7 @@ class MVAPICHRunner(MultiNodeRunner): ...@@ -303,9 +300,7 @@ class MVAPICHRunner(MultiNodeRunner):
if "MVAPICH2-GDR" in mpiname_results: if "MVAPICH2-GDR" in mpiname_results:
exists = True exists = True
else: else:
warnings.warn( warnings.warn(f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}")
f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}"
)
return exists return exists
@property @property
...@@ -316,11 +311,9 @@ class MVAPICHRunner(MultiNodeRunner): ...@@ -316,11 +311,9 @@ class MVAPICHRunner(MultiNodeRunner):
super().validate_args() super().validate_args()
#TODO: Allow for include/exclude at node-level but not gpu-level #TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "": if self.args.include != "" or self.args.exclude != "":
raise ValueError( raise ValueError(f"{self.name} backend does not support worker include/exclusion")
f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1: if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError( raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")
f"{self.name} backend does not support limiting num nodes/gpus")
def get_cmd(self, environment, active_resources): def get_cmd(self, environment, active_resources):
devices_per_node = self.resource_pool.values() devices_per_node = self.resource_pool.values()
...@@ -353,5 +346,4 @@ class MVAPICHRunner(MultiNodeRunner): ...@@ -353,5 +346,4 @@ class MVAPICHRunner(MultiNodeRunner):
if self.args.module: if self.args.module:
python_exec.append("-m") python_exec.append("-m")
return mpirun_cmd + export_cmd + python_exec + [self.user_script return mpirun_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments
] + self.user_arguments
# Copyright 2020 The Microsoft DeepSpeed Team # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
""" """
DeepSpeed runner is the main front-end to launching multi-worker DeepSpeed runner is the main front-end to launching multi-worker
training jobs with DeepSpeed. By default this uses pdsh to parallel training jobs with DeepSpeed. By default this uses pdsh to parallel
...@@ -36,9 +39,9 @@ PDSH_MAX_FAN_OUT = 1024 ...@@ -36,9 +39,9 @@ PDSH_MAX_FAN_OUT = 1024
def parse_args(args=None): def parse_args(args=None):
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="DeepSpeed runner to help launch distributed "
description="DeepSpeed runner to help launch distributed " "multi-node/multi-gpu training jobs.",
"multi-node/multi-gpu training jobs.") formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-H", parser.add_argument("-H",
"--hostfile", "--hostfile",
...@@ -109,12 +112,11 @@ def parse_args(args=None): ...@@ -109,12 +112,11 @@ def parse_args(args=None):
help="(optional) IP address of node 0, will be " help="(optional) IP address of node 0, will be "
"inferred via 'hostname -I' if not specified.") "inferred via 'hostname -I' if not specified.")
parser.add_argument( parser.add_argument("--launcher",
"--launcher", default=PDSH_LAUNCHER,
default=PDSH_LAUNCHER, type=str,
type=str, help="(optional) choose launcher backend for multi-node "
help="(optional) choose launcher backend for multi-node " "training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH.")
"training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH.")
parser.add_argument("--launcher_args", parser.add_argument("--launcher_args",
default="", default="",
...@@ -147,37 +149,40 @@ def parse_args(args=None): ...@@ -147,37 +149,40 @@ def parse_args(args=None):
help="Force multi-node launcher mode, helps in cases where user " help="Force multi-node launcher mode, helps in cases where user "
"wants to launch on single remote node.") "wants to launch on single remote node.")
parser.add_argument( parser.add_argument("--save_pid",
"--save_pid", action="store_true",
action="store_true", help="Save file containing launcher process id (pid) at /tmp/<main-pid>.ds, "
help="Save file containing launcher process id (pid) at /tmp/<main-pid>.ds, " "where <main-pid> is the pid of the first process that invoked `deepspeed`. "
"where <main-pid> is the pid of the first process that invoked `deepspeed`. " "Useful when launching deepspeed processes programmatically.")
"Useful when launching deepspeed processes programmatically.")
parser.add_argument("--enable_each_rank_log",
parser.add_argument( default="None",
"--enable_each_rank_log", type=str,
default="None", help="redirect the stdout and stderr from each rank into different log files")
type=str,
help="redirect the stdout and stderr from each rank into different log files") parser.add_argument("--autotuning",
default="",
parser.add_argument( choices=["tune", "run"],
"--autotuning", type=str,
default="", help="Run DeepSpeed autotuner to discover optimal configuration parameters "
choices=["tune", "before running job.")
"run"],
type=str,
help="Run DeepSpeed autotuner to discover optimal configuration parameters "
"before running job.")
parser.add_argument("--elastic_training", parser.add_argument("--elastic_training",
action="store_true", action="store_true",
help="Enable elastic training support in DeepSpeed.") help="Enable elastic training support in DeepSpeed.")
parser.add_argument("user_script", parser.add_argument("user_script", type=str, help="User script to launch, followed by any required "
type=str,
help="User script to launch, followed by any required "
"arguments.") "arguments.")
parser.add_argument('user_args', nargs=argparse.REMAINDER) parser.add_argument('user_args', nargs=argparse.REMAINDER)
parser.add_argument("--bind_cores_to_rank",
action="store_true",
help="Bind each rank to different cores of the host")
parser.add_argument("--bind_core_list",
type=str,
default=None,
help="List of cores to bind to with comma separated list of "
"numbers and range. i.e. 1,3-5,7 => [1,3,4,5,7]. When not "
"specified, all cores on system would be used rank binding")
return parser.parse_args(args=args) return parser.parse_args(args=args)
...@@ -213,21 +218,15 @@ def _parse_hostfile(hostfile_lines): ...@@ -213,21 +218,15 @@ def _parse_hostfile(hostfile_lines):
num_slots = int(match.group(2)) num_slots = int(match.group(2))
if host in resource_pool: if host in resource_pool:
logger.error(f"Bad hostfile text: {hostfile_lines}") logger.error(f"Bad hostfile text: {hostfile_lines}")
raise ValueError( raise ValueError(f"Hostfile contains multiple entries for {host}, unable to proceed with launching")
f"Hostfile contains multiple entries for {host}, unable to proceed with launching"
)
resource_pool[host] = num_slots resource_pool[host] = num_slots
else: else:
logger.error(f"Bad hostfile text: {hostfile_lines}") logger.error(f"Bad hostfile text: {hostfile_lines}")
raise ValueError( raise ValueError("Hostfile contains a bad entry: {line}, unable to proceed with launching")
"Hostfile contains a bad entry: {line}, unable to proceed with launching"
)
if len(resource_pool) == 0: if len(resource_pool) == 0:
logger.error(f"Bad hostfile text: {hostfile_lines}") logger.error(f"Bad hostfile text: {hostfile_lines}")
raise ValueError( raise ValueError("Hostfile is empty or not formatted correctly, unable to proceed with launching.")
"Hostfile is empty or not formatted correctly, unable to proceed with launching."
)
return resource_pool return resource_pool
...@@ -337,9 +336,7 @@ def parse_inclusion_exclusion(resource_pool, inclusion, exclusion): ...@@ -337,9 +336,7 @@ def parse_inclusion_exclusion(resource_pool, inclusion, exclusion):
for hostname, slots in resource_pool.items(): for hostname, slots in resource_pool.items():
active_resources[hostname] = list(range(slots)) active_resources[hostname] = list(range(slots))
return parse_resource_filter(active_resources, return parse_resource_filter(active_resources, include_str=inclusion, exclude_str=exclusion)
include_str=inclusion,
exclude_str=exclusion)
def encode_world_info(world_info): def encode_world_info(world_info):
...@@ -389,8 +386,7 @@ def main(args=None): ...@@ -389,8 +386,7 @@ def main(args=None):
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "") cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if not resource_pool and len(cuda_visible_devices): if not resource_pool and len(cuda_visible_devices):
detected_str = f"Detected CUDA_VISIBLE_DEVICES={cuda_visible_devices}" detected_str = f"Detected CUDA_VISIBLE_DEVICES={cuda_visible_devices}"
if len(args.include) or len( if len(args.include) or len(args.exclude) or args.num_nodes > 1 or args.num_gpus > 0:
args.exclude) or args.num_nodes > 1 or args.num_gpus > 0:
print( print(
f"{detected_str} but ignoring it because one or several of --include/--exclude/--num_gpus/--num_nodes cl args were used. If you want to use CUDA_VISIBLE_DEVICES don't pass any of these arguments to deepspeed." f"{detected_str} but ignoring it because one or several of --include/--exclude/--num_gpus/--num_nodes cl args were used. If you want to use CUDA_VISIBLE_DEVICES don't pass any of these arguments to deepspeed."
) )
...@@ -416,20 +412,17 @@ def main(args=None): ...@@ -416,20 +412,17 @@ def main(args=None):
if not multi_node_exec and args.num_nodes > 1: if not multi_node_exec and args.num_nodes > 1:
raise ValueError("Num nodes is >1 but no extra nodes available via hostfile") raise ValueError("Num nodes is >1 but no extra nodes available via hostfile")
active_resources = parse_inclusion_exclusion(resource_pool, active_resources = parse_inclusion_exclusion(resource_pool, args.include, args.exclude)
args.include,
args.exclude)
env = os.environ.copy() env = os.environ.copy()
# validate that passwordless-ssh is workly properly with this hostfile # validate that passwordless-ssh is workly properly with this hostfile
if multi_node_exec and not args.no_ssh_check: if multi_node_exec and not args.no_ssh_check:
first_host = list(active_resources.keys())[0] first_host = list(active_resources.keys())[0]
try: try:
subprocess.check_call( subprocess.check_call(f'ssh -o PasswordAuthentication=no {first_host} hostname',
f'ssh -o PasswordAuthentication=no {first_host} hostname', stderr=subprocess.DEVNULL,
stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
stdout=subprocess.DEVNULL, shell=True)
shell=True)
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
raise RuntimeError( raise RuntimeError(
f"Using hostfile at {args.hostfile} but host={first_host} was not reachable via ssh. If you are running with a single node please remove {args.hostfile} or setup passwordless ssh." f"Using hostfile at {args.hostfile} but host={first_host} was not reachable via ssh. If you are running with a single node please remove {args.hostfile} or setup passwordless ssh."
...@@ -481,13 +474,8 @@ def main(args=None): ...@@ -481,13 +474,8 @@ def main(args=None):
if not multi_node_exec: if not multi_node_exec:
deepspeed_launch = [ deepspeed_launch = [
sys.executable, sys.executable, "-u", "-m", "deepspeed.launcher.launch", f"--world_info={world_info_base64}",
"-u", f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
"-m",
"deepspeed.launcher.launch",
f"--world_info={world_info_base64}",
f"--master_addr={args.master_addr}",
f"--master_port={args.master_port}"
] ]
if args.no_python: if args.no_python:
deepspeed_launch.append("--no_python") deepspeed_launch.append("--no_python")
...@@ -498,12 +486,15 @@ def main(args=None): ...@@ -498,12 +486,15 @@ def main(args=None):
if args.save_pid: if args.save_pid:
deepspeed_launch += ["--save_pid", f"{os.getpid()}"] deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
if args.enable_each_rank_log: if args.enable_each_rank_log:
deepspeed_launch.append( deepspeed_launch.append(f"--enable_each_rank_log={args.enable_each_rank_log}")
f"--enable_each_rank_log={args.enable_each_rank_log}")
if args.elastic_training: if args.elastic_training:
deepspeed_launch.append("--enable_elastic_training") deepspeed_launch.append("--enable_elastic_training")
deepspeed_launch.append(f"--max_elastic_nodes={args.max_elastic_nodes}") deepspeed_launch.append(f"--max_elastic_nodes={args.max_elastic_nodes}")
deepspeed_launch.append(f"--min_elastic_nodes={args.min_elastic_nodes}") deepspeed_launch.append(f"--min_elastic_nodes={args.min_elastic_nodes}")
if args.bind_cores_to_rank:
deepspeed_launch.append("--bind_cores_to_rank")
if args.bind_core_list != None:
deepspeed_launch.append(f"--bind_core_list={args.bind_core_list}")
cmd = deepspeed_launch + [args.user_script] + args.user_args cmd = deepspeed_launch + [args.user_script] + args.user_args
else: else:
args.launcher = args.launcher.lower() args.launcher = args.launcher.lower()
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .transformers.ds_transformer import DeepSpeedTransformerInference from .transformers.ds_transformer import DeepSpeedTransformerInference
from .transformers.clip_encoder import DSClipEncoder from .transformers.clip_encoder import DSClipEncoder
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team''' '''Copyright The Microsoft DeepSpeed Team'''
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import torch import torch
from ..features.cuda_graph import CUDAGraph from ..features.cuda_graph import CUDAGraph
class DSUNet(CUDAGraph, torch.nn.Module): class DSUNet(CUDAGraph, torch.nn.Module):
def __init__(self, unet, enable_cuda_graph=True): def __init__(self, unet, enable_cuda_graph=True):
super().__init__(enable_cuda_graph=enable_cuda_graph) super().__init__(enable_cuda_graph=enable_cuda_graph)
self.unet = unet self.unet = unet
...@@ -59,5 +62,12 @@ class DSUNet(CUDAGraph, torch.nn.Module): ...@@ -59,5 +62,12 @@ class DSUNet(CUDAGraph, torch.nn.Module):
self.cuda_graph_created = True self.cuda_graph_created = True
def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True): def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True, cross_attention_kwargs=None):
return self.unet(sample, timestamp, encoder_hidden_states, return_dict) if cross_attention_kwargs:
return self.unet(sample,
timestamp,
encoder_hidden_states,
return_dict,
cross_attention_kwargs=cross_attention_kwargs)
else:
return self.unet(sample, timestamp, encoder_hidden_states, return_dict)
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import torch import torch
from ..features.cuda_graph import CUDAGraph from ..features.cuda_graph import CUDAGraph
class DSVAE(CUDAGraph, torch.nn.Module): class DSVAE(CUDAGraph, torch.nn.Module):
def __init__(self, vae, enable_cuda_graph=True): def __init__(self, vae, enable_cuda_graph=True):
super().__init__(enable_cuda_graph=enable_cuda_graph) super().__init__(enable_cuda_graph=enable_cuda_graph)
self.vae = vae self.vae = vae
self.config = vae.config
self.device = self.vae.device self.device = self.vae.device
self.dtype = self.vae.dtype self.dtype = self.vae.dtype
self.vae.requires_grad_(requires_grad=False) self.vae.requires_grad_(requires_grad=False)
...@@ -44,8 +48,7 @@ class DSVAE(CUDAGraph, torch.nn.Module): ...@@ -44,8 +48,7 @@ class DSVAE(CUDAGraph, torch.nn.Module):
self.static_decoder_kwargs = kwargs self.static_decoder_kwargs = kwargs
with torch.cuda.graph(self._decoder_cuda_graph): with torch.cuda.graph(self._decoder_cuda_graph):
self.static_decoder_output = self._decode(*self.static_decoder_inputs, self.static_decoder_output = self._decode(*self.static_decoder_inputs, **self.static_decoder_kwargs)
**self.static_decoder_kwargs)
self.decoder_cuda_graph_created = True self.decoder_cuda_graph_created = True
...@@ -88,8 +91,7 @@ class DSVAE(CUDAGraph, torch.nn.Module): ...@@ -88,8 +91,7 @@ class DSVAE(CUDAGraph, torch.nn.Module):
self.static_encoder_kwargs = kwargs self.static_encoder_kwargs = kwargs
with torch.cuda.graph(self._encoder_cuda_graph): with torch.cuda.graph(self._encoder_cuda_graph):
self.static_encoder_output = self._encode(*self.static_encoder_inputs, self.static_encoder_output = self._encode(*self.static_encoder_inputs, **self.static_encoder_kwargs)
**self.static_encoder_kwargs)
self.encoder_cuda_graph_created = True self.encoder_cuda_graph_created = True
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team''' '''Copyright The Microsoft DeepSpeed Team'''
''' # Copyright (c) Microsoft Corporation.
Copyright 2023 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
class CUDAGraph(ABC): class CUDAGraph(ABC):
def __init__(self, enable_cuda_graph=False): def __init__(self, enable_cuda_graph=False):
super().__init__() super().__init__()
self.enable_cuda_graph = enable_cuda_graph self.enable_cuda_graph = enable_cuda_graph
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team''' '''Copyright The Microsoft DeepSpeed Team'''
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import torch import torch
from deepspeed.accelerator import get_accelerator from deepspeed.accelerator import get_accelerator
from ..features.cuda_graph import CUDAGraph from ..features.cuda_graph import CUDAGraph
class DSClipEncoder(CUDAGraph, torch.nn.Module): class DSClipEncoder(CUDAGraph, torch.nn.Module):
def __init__(self, enc, enable_cuda_graph=False): def __init__(self, enc, enable_cuda_graph=False):
super().__init__(enable_cuda_graph=enable_cuda_graph) super().__init__(enable_cuda_graph=enable_cuda_graph)
enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask
...@@ -22,11 +25,7 @@ class DSClipEncoder(CUDAGraph, torch.nn.Module): ...@@ -22,11 +25,7 @@ class DSClipEncoder(CUDAGraph, torch.nn.Module):
self.config = self.enc.config self.config = self.enc.config
def _build_causal_attention_mask(self, bsz, seq_len, dtype): def _build_causal_attention_mask(self, bsz, seq_len, dtype):
mask = torch.empty(bsz, mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=get_accelerator().current_device_name())
seq_len,
seq_len,
dtype=dtype,
device=get_accelerator().current_device_name())
mask.fill_(torch.tensor(torch.finfo(dtype).min)) mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) mask.triu_(1)
mask = mask.unsqueeze(1) mask = mask.unsqueeze(1)
...@@ -69,9 +68,8 @@ class DSClipEncoder(CUDAGraph, torch.nn.Module): ...@@ -69,9 +68,8 @@ class DSClipEncoder(CUDAGraph, torch.nn.Module):
self.static_kwargs[self.iter] = kwargs self.static_kwargs[self.iter] = kwargs
with torch.cuda.graph(self._cuda_graphs[self.iter]): with torch.cuda.graph(self._cuda_graphs[self.iter]):
self.static_output[self.iter] = self._forward( self.static_output[self.iter] = self._forward(*self.static_inputs[self.iter],
*self.static_inputs[self.iter], **self.static_kwargs[self.iter])
**self.static_kwargs[self.iter])
self.cuda_graph_created[self.iter] = True self.cuda_graph_created[self.iter] = True
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch.nn as nn import torch.nn as nn
class DeepSpeedTransformerBase(nn.module): class DeepSpeedTransformerBase(nn.module):
def __init__(self): def __init__(self):
pass pass
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
...@@ -8,6 +9,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee ...@@ -8,6 +9,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee
class DeepSpeedBERTInference(DeepSpeedTransformerInference): class DeepSpeedBERTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed BERT Transformer Layer. """Initialize the DeepSpeed BERT Transformer Layer.
""" """
def __init__(self, def __init__(self,
config, config,
mp_group=None, mp_group=None,
...@@ -15,9 +17,4 @@ class DeepSpeedBERTInference(DeepSpeedTransformerInference): ...@@ -15,9 +17,4 @@ class DeepSpeedBERTInference(DeepSpeedTransformerInference):
quantize_groups=1, quantize_groups=1,
merge_count=1, merge_count=1,
mlp_extra_grouping=False): mlp_extra_grouping=False):
super().__init__(config, super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
...@@ -8,6 +9,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee ...@@ -8,6 +9,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee
class DeepSpeedBloomInference(DeepSpeedTransformerInference): class DeepSpeedBloomInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed Bloom Transformer Layer. """Initialize the DeepSpeed Bloom Transformer Layer.
""" """
def __init__(self, def __init__(self,
config, config,
mp_group=None, mp_group=None,
...@@ -15,9 +17,4 @@ class DeepSpeedBloomInference(DeepSpeedTransformerInference): ...@@ -15,9 +17,4 @@ class DeepSpeedBloomInference(DeepSpeedTransformerInference):
quantize_groups=1, quantize_groups=1,
merge_count=1, merge_count=1,
mlp_extra_grouping=False): mlp_extra_grouping=False):
super().__init__(config, super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
...@@ -8,6 +9,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee ...@@ -8,6 +9,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee
class DeepSpeedGPTInference(DeepSpeedTransformerInference): class DeepSpeedGPTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed GPT Transformer Layer. """Initialize the DeepSpeed GPT Transformer Layer.
""" """
def __init__(self, def __init__(self,
config, config,
mp_group=None, mp_group=None,
...@@ -15,9 +17,4 @@ class DeepSpeedGPTInference(DeepSpeedTransformerInference): ...@@ -15,9 +17,4 @@ class DeepSpeedGPTInference(DeepSpeedTransformerInference):
quantize_groups=1, quantize_groups=1,
merge_count=1, merge_count=1,
mlp_extra_grouping=False): mlp_extra_grouping=False):
super().__init__(config, super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
...@@ -8,6 +9,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee ...@@ -8,6 +9,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee
class DeepSpeedMegatronGPTInference(DeepSpeedTransformerInference): class DeepSpeedMegatronGPTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed Megatron GPT Transformer Layer. """Initialize the DeepSpeed Megatron GPT Transformer Layer.
""" """
def __init__(self, def __init__(self,
config, config,
mp_group=None, mp_group=None,
...@@ -15,9 +17,4 @@ class DeepSpeedMegatronGPTInference(DeepSpeedTransformerInference): ...@@ -15,9 +17,4 @@ class DeepSpeedMegatronGPTInference(DeepSpeedTransformerInference):
quantize_groups=1, quantize_groups=1,
merge_count=1, merge_count=1,
mlp_extra_grouping=False): mlp_extra_grouping=False):
super().__init__(config, super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
...@@ -8,6 +9,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee ...@@ -8,6 +9,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee
class DeepSpeedOPTInference(DeepSpeedTransformerInference): class DeepSpeedOPTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed OPT Transformer Layer. """Initialize the DeepSpeed OPT Transformer Layer.
""" """
def __init__(self, def __init__(self,
config, config,
mp_group=None, mp_group=None,
...@@ -15,9 +17,4 @@ class DeepSpeedOPTInference(DeepSpeedTransformerInference): ...@@ -15,9 +17,4 @@ class DeepSpeedOPTInference(DeepSpeedTransformerInference):
quantize_groups=1, quantize_groups=1,
merge_count=1, merge_count=1,
mlp_extra_grouping=False): mlp_extra_grouping=False):
super().__init__(config, super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -56,37 +57,26 @@ class DeepSpeedTransformerInference(nn.Module): ...@@ -56,37 +57,26 @@ class DeepSpeedTransformerInference(nn.Module):
log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0]) log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
if self.config.bigscience_bloom: if self.config.bigscience_bloom:
self.attention = BloomSelfAttention(self.config, self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
mp_group,
quantize_scales,
quantize_groups,
merge_count)
else: else:
self.attention = DeepSpeedSelfAttention(self.config, self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
mp_group,
quantize_scales,
quantize_groups,
merge_count) merge_count)
self.mlp = DeepSpeedMLP(self.config, self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping) mlp_extra_grouping)
device = get_accelerator().current_device_name( device = get_accelerator().current_device_name() # if config.bigscience_bloom else 'cpu'
) # if config.bigscience_bloom else 'cpu' if self.config.set_empty_params:
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, self.norm_w = None
dtype=data_type, self.norm_b = None
device=device), else:
requires_grad=False) self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, requires_grad=False)
dtype=data_type, self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
device=device), requires_grad=False)
requires_grad=False)
self.layer_past = None self.layer_past = None
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \ self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \
inference_cuda_module.allocate_workspace_fp16 inference_cuda_module.allocate_workspace_fp16
self._alloc_workspace = True
@classmethod @classmethod
def reset_cache(cls): def reset_cache(cls):
...@@ -114,25 +104,25 @@ class DeepSpeedTransformerInference(nn.Module): ...@@ -114,25 +104,25 @@ class DeepSpeedTransformerInference(nn.Module):
# TODO(arashb): 'layer_head_mask' and 'past_key_value' are only added to satisfy the OPT models API. # TODO(arashb): 'layer_head_mask' and 'past_key_value' are only added to satisfy the OPT models API.
# This needs to be redesigned later! # This needs to be redesigned later!
layer_head_mask=None, layer_head_mask=None,
past_key_value=None): past_key_value=None,
**kwargs):
if x is not None: if x is not None:
input = x input = x
if "hidden_states" in kwargs:
input = kwargs["hidden_states"]
input_mask = (input_mask if attn_mask is None else input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask
attn_mask) if attention_mask is None else attention_mask
# Allocate memory only on first layer forward # Allocate memory only on first layer forward
if self.config.layer_id == 0: if self.config.layer_id == 0 and self._alloc_workspace:
self.allocate_workspace(self.config.hidden_size, self.allocate_workspace(self.config.hidden_size, self.config.heads,
self.config.heads,
input.size()[1], input.size()[1],
input.size()[0], input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
DeepSpeedTransformerInference.layer_id,
self.config.mp_size,
self.config.bigscience_bloom, self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0, dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
self.config.max_out_tokens) self.config.min_out_tokens)
self._alloc_workspace = False
get_present = (get_present or get_key_value or use_cache) get_present = (get_present or get_key_value or use_cache)
input_mask = input_mask if attention_mask is None else attention_mask input_mask = input_mask if attention_mask is None else attention_mask
...@@ -171,10 +161,7 @@ class DeepSpeedTransformerInference(nn.Module): ...@@ -171,10 +161,7 @@ class DeepSpeedTransformerInference(nn.Module):
output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob) output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)
if not self.config.pre_layer_norm: if not self.config.pre_layer_norm:
output = inference_cuda_module.layer_norm(output, output = inference_cuda_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)
self.norm_w,
self.norm_b,
self.config.epsilon)
output = output.to(input_type) output = output.to(input_type)
if get_present: if get_present:
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
from .module_quantize import quantize_transformer_layer from .module_quantize import quantize_transformer_layer
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# Automatic Tensor Parallelism # Automatic Tensor Parallelism
import re import re
...@@ -8,6 +11,7 @@ from .replace_policy import replace_policies ...@@ -8,6 +11,7 @@ from .replace_policy import replace_policies
class AutoTP(): class AutoTP():
def in_module_list(module, module_list): def in_module_list(module, module_list):
for item in module_list: for item in module_list:
if type(item).__name__ == type(module).__name__: if type(item).__name__ == type(module).__name__:
...@@ -28,18 +32,7 @@ class AutoTP(): ...@@ -28,18 +32,7 @@ class AutoTP():
return mlist return mlist
def supported(model): def supported(model):
unsupported = [ unsupported = ['codegen', 'deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
'bloom',
'codegen',
'deberta',
'flaubert',
'fsmt',
'gpt2',
'led',
'longformer',
'xlm',
'xlnet'
]
model = str(model) model = str(model)
key = re.search(r": (.*?)Model", model) key = re.search(r": (.*?)Model", model)
if key is None: if key is None:
...@@ -56,8 +49,7 @@ class AutoTP(): ...@@ -56,8 +49,7 @@ class AutoTP():
for key, submodule in module._modules.items(): for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear): if isinstance(submodule, nn.Linear):
layer_list = layer_list + [parent + "." + key] layer_list = layer_list + [parent + "." + key]
elif isinstance(submodule, elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"] layer_list = layer_list + ["ln"]
else: else:
layer_list = layer_list + AutoTP.get_layers(key, submodule) layer_list = layer_list + AutoTP.get_layers(key, submodule)
...@@ -102,9 +94,7 @@ class AutoTP(): ...@@ -102,9 +94,7 @@ class AutoTP():
for key, submodule in module._modules.items(): for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear): if isinstance(submodule, nn.Linear):
layer_list = layer_list + ["." + key] layer_list = layer_list + ["." + key]
elif isinstance( elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
submodule,
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"] layer_list = layer_list + ["ln"]
else: else:
layer_list = layer_list + AutoTP.get_layers(key, submodule) layer_list = layer_list + AutoTP.get_layers(key, submodule)
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .bert import DS_BERTContainer, HFBertLayerPolicy from .bert import DS_BERTContainer, HFBertLayerPolicy
from .bloom import DS_BloomContainer, BLOOMLayerPolicy, supported_models from .bloom import DS_BloomContainer, BLOOMLayerPolicy, supported_models
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment