Commit 67ea635f authored by aiss

push dsv0.8.2 version

parent 1b2721ad
Pipeline #201 failed with stages in 0 seconds
@@ -7,6 +7,7 @@ per rank for training.
 """
 import os
+import re
 import sys
 import json
 import base64
@@ -14,18 +15,21 @@ import argparse
 import subprocess
 import collections
 from copy import deepcopy
+import signal
+import time

-import torch.cuda
-from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner
-from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER
+from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner
+from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER
 from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
+from ..nebula.constants import NEBULA_EXPORT_ENVS
 from ..utils import logger
 from ..autotuning import Autotuner
+from deepspeed.accelerator import get_accelerator

 DLTS_HOSTFILE = "/job/hostfile"
-EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", "UCX"]
+EXPORT_ENVS = ['MLFLOW', 'NCCL', 'PYTHON', 'MV2', 'UCX']
+EXPORT_ENVS += NEBULA_EXPORT_ENVS
 DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env"
 DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.']
 PDSH_MAX_FAN_OUT = 1024
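The EXPORT_ENVS prefixes above select which of the launching shell's environment variables are forwarded to every node; additional variables can be listed in a .deepspeed_env file placed in the home directory or the current working directory (the DEEPSPEED_ENVIRONMENT_PATHS above). The file holds one KEY=VALUE pair per line, for example (values are placeholders):

    NCCL_SOCKET_IFNAME=eth0
    PYTHONPATH=/home/user/my_project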
@@ -75,6 +79,18 @@ def parse_args(args=None):
                         help="Total number of worker nodes to run on, this will use "
                         "the top N hosts from the given hostfile.")

+    parser.add_argument("--min_elastic_nodes",
+                        type=int,
+                        default=-1,
+                        help="Minimum number of nodes to run elastic training on. "
+                        "Default is 1 when elastic training is enabled")
+
+    parser.add_argument("--max_elastic_nodes",
+                        type=int,
+                        default=-1,
+                        help="Maximum number of nodes to run elastic training on. "
+                        "Default is num_nodes when elastic training is enabled")
+
     parser.add_argument("--num_gpus",
                         type=int,
                         default=-1,
@@ -93,11 +109,12 @@ def parse_args(args=None):
                         help="(optional) IP address of node 0, will be "
                         "inferred via 'hostname -I' if not specified.")

-    parser.add_argument("--launcher",
-                        default=PDSH_LAUNCHER,
-                        type=str,
-                        help="(optional) choose launcher backend for multi-node "
-                        "training. Options currently include PDSH, OpenMPI, MVAPICH.")
+    parser.add_argument(
+        "--launcher",
+        default=PDSH_LAUNCHER,
+        type=str,
+        help="(optional) choose launcher backend for multi-node "
+        "training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH.")

     parser.add_argument("--launcher_args",
                         default="",
@@ -121,6 +138,10 @@ def parse_args(args=None):
                         help="Do not pass local_rank as an argument when calling "
                         "the user's training script.")

+    parser.add_argument("--no_ssh_check",
+                        action="store_true",
+                        help="Do not perform ssh check in multi-node launcher model")
+
     parser.add_argument("--force_multi",
                         action="store_true",
                         help="Force multi-node launcher mode, helps in cases where user "
@@ -133,6 +154,12 @@ def parse_args(args=None):
         "where <main-pid> is the pid of the first process that invoked `deepspeed`. "
         "Useful when launching deepspeed processes programmatically.")

+    parser.add_argument(
+        "--enable_each_rank_log",
+        default="None",
+        type=str,
+        help="redirect the stdout and stderr from each rank into different log files")
+
     parser.add_argument(
         "--autotuning",
         default="",
@@ -142,6 +169,10 @@ def parse_args(args=None):
         help="Run DeepSpeed autotuner to discover optimal configuration parameters "
         "before running job.")

+    parser.add_argument("--elastic_training",
+                        action="store_true",
+                        help="Enable elastic training support in DeepSpeed.")
+
     parser.add_argument("user_script",
                         type=str,
                         help="User script to launch, followed by any required "
@@ -158,25 +189,45 @@ def fetch_hostfile(hostfile_path):
     # e.g., worker-0 slots=16
     with open(hostfile_path, 'r') as fd:
-        resource_pool = collections.OrderedDict()
-        for line in fd.readlines():
-            line = line.strip()
-            if line == '':
-                # skip empty lines
-                continue
-            try:
-                hostname, slots = line.split()
-                _, slot_count = slots.split("=")
-                slot_count = int(slot_count)
-            except ValueError as err:
-                logger.error("Hostfile is not formatted correctly, unable to "
-                             "proceed with training.")
-                raise err
-            if hostname in resource_pool:
-                logger.error("Hostfile contains duplicate hosts, unable to "
-                             "proceed with training.")
-                raise ValueError(f"host {hostname} is already defined")
-            resource_pool[hostname] = slot_count
+        hostfile_text = fd.readlines()
+
+    return _parse_hostfile(hostfile_text)
+
+
+def _parse_hostfile(hostfile_lines):
+    # Regex matches one or more non-whitespace characters (\S+) at the start of
+    # the line, followed by one or more whitespace characters (\s+), followed
+    # by the string "slots=", followed by one or more digits (\d+).
+    pattern = r'^(\S+)\s+slots=(\d+)'
+
+    resource_pool = collections.OrderedDict()
+
+    for line in hostfile_lines:
+        line = line.strip()
+        match = re.search(pattern, line)
+        if line.startswith("#") or line == "":
+            # hostfile comment or empty line, ignore
+            continue
+        elif match:
+            host = match.group(1)
+            num_slots = int(match.group(2))
+            if host in resource_pool:
+                logger.error(f"Bad hostfile text: {hostfile_lines}")
+                raise ValueError(
+                    f"Hostfile contains multiple entries for {host}, unable to proceed with launching"
+                )
+            resource_pool[host] = num_slots
+        else:
+            logger.error(f"Bad hostfile text: {hostfile_lines}")
+            raise ValueError(
+                f"Hostfile contains a bad entry: {line}, unable to proceed with launching"
+            )
+
+    if len(resource_pool) == 0:
+        logger.error(f"Bad hostfile text: {hostfile_lines}")
+        raise ValueError(
+            "Hostfile is empty or not formatted correctly, unable to proceed with launching."
+        )

     return resource_pool
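For illustration, a hostfile such as the following (hostnames and slot counts are placeholders):

    worker-0 slots=8
    worker-1 slots=8
    # worker-2 is temporarily out of rotation

is parsed by _parse_hostfile into OrderedDict([('worker-0', 8), ('worker-1', 8)]); comment and empty lines are skipped, while duplicate hosts and otherwise malformed lines raise ValueError.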
@@ -305,14 +356,33 @@ def run_autotuning(args, active_resources):
     tuner.print_tuning_results()

     logger.info("[End] Running autotuning")
+    tuner.write_optimal_config()

     if args.autotuning == "run":
         tuner.run_after_tuning()


+def parse_num_nodes(str_num_nodes: str, elastic_training: bool):
+    node_list = str_num_nodes.split(":")
+
+    if len(node_list) == 1:
+        min_nodes, max_nodes = int(node_list[0]), -1
+    elif len(node_list) == 2 and elastic_training:
+        min_nodes, max_nodes = int(node_list[0]), int(node_list[1])
+    elif len(node_list) == 2 and not elastic_training:
+        raise RuntimeError("MIN:MAX format is only supported in elastic training")
+    else:
+        raise RuntimeError("num_nodes {} is not in MIN:MAX format".format(str_num_nodes))
+
+    return min_nodes, max_nodes
+
+
 def main(args=None):
     args = parse_args(args)

+    if args.elastic_training:
+        assert args.master_addr != "", "Master Addr is required when elastic training is enabled"
+
     resource_pool = fetch_hostfile(args.hostfile)

     # respect CUDA_VISIBLE_DEVICES for a single node and no explicit resource filters
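For reference, the new parse_num_nodes helper maps the textual node count to a (min, max) pair:

    parse_num_nodes("2", False)    # -> (2, -1)
    parse_num_nodes("2:4", True)   # -> (2, 4)
    parse_num_nodes("2:4", False)  # -> RuntimeError: MIN:MAX format is only supported in elastic training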
@@ -336,7 +406,7 @@ def main(args=None):
         multi_node_exec = True
     if not resource_pool:
         resource_pool = {}
-        device_count = torch.cuda.device_count()
+        device_count = get_accelerator().device_count()
         if device_count == 0:
             raise RuntimeError("Unable to proceed, no GPU resources available")
         resource_pool['localhost'] = device_count
@@ -352,7 +422,7 @@ def main(args=None):
     env = os.environ.copy()

     # validate that passwordless-ssh is working properly with this hostfile
-    if multi_node_exec:
+    if multi_node_exec and not args.no_ssh_check:
         first_host = list(active_resources.keys())[0]
         try:
             subprocess.check_call(
@@ -369,8 +439,18 @@ def main(args=None):
         assert multi_node_exec
         first_host = list(active_resources.keys())[0]
         hostname_cmd = [f"ssh {first_host} hostname -I"]
-        result = subprocess.check_output(hostname_cmd, shell=True)
+        try:
+            result = subprocess.check_output(hostname_cmd, shell=True)
+        except subprocess.CalledProcessError as err:
+            logger.error(
+                "Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr"
+            )
+            raise err
         args.master_addr = result.decode('utf-8').split()[0]
+        if not args.master_addr:
+            raise RuntimeError(
+                f"Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr"
+            )
         logger.info(f"Using IP address of {args.master_addr} for node {first_host}")

     if args.autotuning != "":
@@ -391,6 +471,9 @@ def main(args=None):
             updated_active_resources[hostname] = list(range(args.num_gpus))
         active_resources = updated_active_resources

+    if args.elastic_training:
+        assert not args.no_local_rank, "--no_local_rank argument is not supported in Elastic training"
+
     # encode world info as base64 to make it easier to pass via command line
     world_info_base64 = encode_world_info(active_resources)
@@ -414,6 +497,13 @@ def main(args=None):
             deepspeed_launch.append("--no_local_rank")
         if args.save_pid:
             deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
+        if args.enable_each_rank_log:
+            deepspeed_launch.append(
+                f"--enable_each_rank_log={args.enable_each_rank_log}")
+        if args.elastic_training:
+            deepspeed_launch.append("--enable_elastic_training")
+            deepspeed_launch.append(f"--max_elastic_nodes={args.max_elastic_nodes}")
+            deepspeed_launch.append(f"--min_elastic_nodes={args.min_elastic_nodes}")
         cmd = deepspeed_launch + [args.user_script] + args.user_args
     else:
         args.launcher = args.launcher.lower()
@@ -421,8 +511,12 @@ def main(args=None):
             runner = PDSHRunner(args, world_info_base64)
         elif args.launcher == OPENMPI_LAUNCHER:
             runner = OpenMPIRunner(args, world_info_base64, resource_pool)
+        elif args.launcher == MPICH_LAUNCHER:
+            runner = MPICHRunner(args, world_info_base64, resource_pool)
         elif args.launcher == MVAPICH_LAUNCHER:
             runner = MVAPICHRunner(args, world_info_base64, resource_pool)
+        elif args.launcher == SLURM_LAUNCHER:
+            runner = SlurmRunner(args, world_info_base64, resource_pool)
         else:
             raise NotImplementedError(f"Unknown launcher {args.launcher}")
@@ -448,11 +542,26 @@ def main(args=None):
                 key, val = var.split('=', maxsplit=1)
                 runner.add_export(key, val)

-        cmd = runner.get_cmd(env, active_resources)
+        if args.launcher == PDSH_LAUNCHER:
+            cmd, kill_cmd = runner.get_cmd(env, active_resources)
+        else:
+            cmd = runner.get_cmd(env, active_resources)

     logger.info(f"cmd = {' '.join(cmd)}")
     result = subprocess.Popen(cmd, env=env)

+    def sigkill_handler(signum, frame):
+        result.send_signal(signal.SIGINT)
+        time.sleep(0.1)
+        result.send_signal(signal.SIGTERM)
+        result_kill = subprocess.Popen(kill_cmd, env=env)
+        result_kill.wait()
+        time.sleep(1)
+        sys.exit(1)
+
+    if args.launcher == PDSH_LAUNCHER:
+        signal.signal(signal.SIGINT, sigkill_handler)
+
     result.wait()

     # In case of failure must propagate the error-condition back to the caller (usually shell). The
'''Copyright The Microsoft DeepSpeed Team'''
from .transformers.ds_transformer import DeepSpeedTransformerInference
from .transformers.clip_encoder import DSClipEncoder
'''Copyright The Microsoft DeepSpeed Team'''
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
from ..features.cuda_graph import CUDAGraph
class DSUNet(CUDAGraph, torch.nn.Module):
def __init__(self, unet, enable_cuda_graph=True):
super().__init__(enable_cuda_graph=enable_cuda_graph)
self.unet = unet
# SD pipeline accesses this attribute
self.in_channels = unet.in_channels
self.device = self.unet.device
self.dtype = self.unet.dtype
self.config = self.unet.config
self.fwd_count = 0
self.unet.requires_grad_(requires_grad=False)
self.unet.to(memory_format=torch.channels_last)
self.cuda_graph_created = False
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_inputs[i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_kwargs[k].copy_(kwargs[k])
self._cuda_graphs.replay()
return self.static_output
def forward(self, *inputs, **kwargs):
if self.enable_cuda_graph:
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
return outputs
else:
return self._forward(*inputs, **kwargs)
def _create_cuda_graph(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_stream):
for i in range(3):
ret = self._forward(*inputs, **kwargs)
torch.cuda.current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._cuda_graphs = torch.cuda.CUDAGraph()
self.static_inputs = inputs
self.static_kwargs = kwargs
with torch.cuda.graph(self._cuda_graphs):
self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)
self.cuda_graph_created = True
def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True):
return self.unet(sample, timestamp, encoder_hidden_states, return_dict)
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
from ..features.cuda_graph import CUDAGraph
class DSVAE(CUDAGraph, torch.nn.Module):
def __init__(self, vae, enable_cuda_graph=True):
super().__init__(enable_cuda_graph=enable_cuda_graph)
self.vae = vae
self.device = self.vae.device
self.dtype = self.vae.dtype
self.vae.requires_grad_(requires_grad=False)
self.decoder_cuda_graph_created = False
self.encoder_cuda_graph_created = False
self.all_cuda_graph_created = False
def _graph_replay_decoder(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_decoder_inputs[i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_decoder_kwargs[k].copy_(kwargs[k])
self._decoder_cuda_graph.replay()
return self.static_decoder_output
def _decode(self, x, return_dict=True):
return self.vae.decode(x, return_dict=return_dict)
def _create_cuda_graph_decoder(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_stream):
for i in range(3):
ret = self._decode(*inputs, **kwargs)
torch.cuda.current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._decoder_cuda_graph = torch.cuda.CUDAGraph()
self.static_decoder_inputs = inputs
self.static_decoder_kwargs = kwargs
with torch.cuda.graph(self._decoder_cuda_graph):
self.static_decoder_output = self._decode(*self.static_decoder_inputs,
**self.static_decoder_kwargs)
self.decoder_cuda_graph_created = True
def decode(self, *inputs, **kwargs):
if self.enable_cuda_graph:
if self.decoder_cuda_graph_created:
outputs = self._graph_replay_decoder(*inputs, **kwargs)
else:
self._create_cuda_graph_decoder(*inputs, **kwargs)
outputs = self._graph_replay_decoder(*inputs, **kwargs)
return outputs
else:
return self._decode(*inputs, **kwargs)
def _graph_replay_encoder(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_encoder_inputs[i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_encoder_kwargs[k].copy_(kwargs[k])
self._encoder_cuda_graph.replay()
return self.static_encoder_output
def _encode(self, x, return_dict=True):
return self.vae.encode(x, return_dict=return_dict)
def _create_cuda_graph_encoder(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_stream):
for i in range(3):
ret = self._encode(*inputs, **kwargs)
torch.cuda.current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._encoder_cuda_graph = torch.cuda.CUDAGraph()
self.static_encoder_inputs = inputs
self.static_encoder_kwargs = kwargs
with torch.cuda.graph(self._encoder_cuda_graph):
self.static_encoder_output = self._encode(*self.static_encoder_inputs,
**self.static_encoder_kwargs)
self.encoder_cuda_graph_created = True
def encode(self, *inputs, **kwargs):
if self.enable_cuda_graph:
if self.encoder_cuda_graph_created:
outputs = self._graph_replay_encoder(*inputs, **kwargs)
else:
self._create_cuda_graph_encoder(*inputs, **kwargs)
outputs = self._graph_replay_encoder(*inputs, **kwargs)
return outputs
else:
return self._encode(*inputs, **kwargs)
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_inputs[i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_kwargs[k].copy_(kwargs[k])
self._all_cuda_graph.replay()
return self.static_output
def forward(self, *inputs, **kwargs):
if self.enable_cuda_graph:
if self.all_cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
return outputs
else:
return self._forward(*inputs, **kwargs)
def _create_cuda_graph(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_stream):
for i in range(3):
ret = self._forward(*inputs, **kwargs)
torch.cuda.current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._all_cuda_graph = torch.cuda.CUDAGraph()
self.static_inputs = inputs
self.static_kwargs = kwargs
with torch.cuda.graph(self._all_cuda_graph):
self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)
self.all_cuda_graph_created = True
def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True):
return self.vae(sample, timestamp, encoder_hidden_states, return_dict)
'''Copyright The Microsoft DeepSpeed Team'''
'''
Copyright 2023 The Microsoft DeepSpeed Team
'''
from abc import ABC, abstractmethod
class CUDAGraph(ABC):
def __init__(self, enable_cuda_graph=False):
super().__init__()
self.enable_cuda_graph = enable_cuda_graph
@abstractmethod
def _create_cuda_graph(self):
"""
Create CUDA graph(s)
"""
raise NotImplementedError
@abstractmethod
def _graph_replay(self):
"""
Replay CUDA graph(s)
"""
raise NotImplementedError
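A hypothetical minimal subclass, shown here only to illustrate the warmup/capture/replay contract that DSUNet, DSVAE and DSClipEncoder implement; the class name and its toy _forward are illustrative, not part of DeepSpeed:

import torch


class GraphedDouble(CUDAGraph, torch.nn.Module):
    def __init__(self, enable_cuda_graph=True):
        super().__init__(enable_cuda_graph=enable_cuda_graph)
        self.cuda_graph_created = False

    def _forward(self, x):
        # stand-in for a real module's computation
        return x * 2

    def _create_cuda_graph(self, *inputs, **kwargs):
        # warmup on a side stream so capture starts from a clean state
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            for _ in range(3):
                self._forward(*inputs, **kwargs)
        torch.cuda.current_stream().wait_stream(stream)

        # capture one forward pass; the inputs become the graph's static buffers
        self._graph = torch.cuda.CUDAGraph()
        self.static_inputs = inputs
        with torch.cuda.graph(self._graph):
            self.static_output = self._forward(*self.static_inputs, **kwargs)
        self.cuda_graph_created = True

    def _graph_replay(self, *inputs, **kwargs):
        # copy live tensors into the captured buffers, then replay the graph
        for static, live in zip(self.static_inputs, inputs):
            if torch.is_tensor(live):
                static.copy_(live)
        self._graph.replay()
        return self.static_output

    def forward(self, *inputs, **kwargs):
        if self.enable_cuda_graph:
            if not self.cuda_graph_created:
                self._create_cuda_graph(*inputs, **kwargs)
            return self._graph_replay(*inputs, **kwargs)
        return self._forward(*inputs, **kwargs)

On a CUDA machine, m = GraphedDouble() followed by m(torch.ones(8, device='cuda')) captures the graph on the first call and replays it on subsequent calls.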
'''Copyright The Microsoft DeepSpeed Team'''
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
from deepspeed.accelerator import get_accelerator
from ..features.cuda_graph import CUDAGraph
class DSClipEncoder(CUDAGraph, torch.nn.Module):
def __init__(self, enc, enable_cuda_graph=False):
super().__init__(enable_cuda_graph=enable_cuda_graph)
enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask
self.enc = enc
self.device = self.enc.device
self.dtype = self.enc.dtype
self.cuda_graph_created = [False, False]
self.static_inputs = [None, None]
self.static_kwargs = [None, None]
self.static_output = [None, None]
self._cuda_graphs = [None, None]
self.iter = 0
self.config = self.enc.config
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
mask = torch.empty(bsz,
seq_len,
seq_len,
dtype=dtype,
device=get_accelerator().current_device_name())
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1)
mask = mask.unsqueeze(1)
return mask
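# Worked example: for seq_len=3 (with min = torch.finfo(dtype).min) the mask
# built above is, per batch element and before the unsqueeze:
#   [[0, min, min],
#    [0,   0, min],
#    [0,   0,   0]]
# so each token can attend only to itself and to earlier positions.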
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_inputs[self.iter][i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_kwargs[self.iter][k].copy_(kwargs[k])
self._cuda_graphs[self.iter].replay()
return self.static_output[self.iter]
def forward(self, *inputs, **kwargs):
if self.enable_cuda_graph:
if self.cuda_graph_created[self.iter]:
outputs = self._graph_replay(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
self.iter = (self.iter + 1) % 2
return outputs
else:
return self.enc(*inputs, **kwargs)
def _create_cuda_graph(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_stream):
for i in range(3):
ret = self._forward(*inputs, **kwargs)
torch.cuda.current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._cuda_graphs[self.iter] = torch.cuda.CUDAGraph()
self.static_inputs[self.iter] = inputs
self.static_kwargs[self.iter] = kwargs
with torch.cuda.graph(self._cuda_graphs[self.iter]):
self.static_output[self.iter] = self._forward(
*self.static_inputs[self.iter],
**self.static_kwargs[self.iter])
self.cuda_graph_created[self.iter] = True
def _forward(self, *inputs, **kwargs):
return self.enc(*inputs, **kwargs)
'''Copyright The Microsoft DeepSpeed Team'''
import torch.nn as nn
class DeepSpeedTransformerBase(nn.Module):
def __init__(self):
pass
# this would be the new clean base class that will replace DeepSpeedTransformerInference.
# we currently don't know how this will look like but keeping it here as a placeholder.
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
class DeepSpeedBERTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed BERT Transformer Layer.
"""
def __init__(self,
config,
mp_group=None,
quantize_scales=None,
quantize_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super().__init__(config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
class DeepSpeedBloomInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed Bloom Transformer Layer.
"""
def __init__(self,
config,
mp_group=None,
quantize_scales=None,
quantize_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super().__init__(config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
class DeepSpeedGPTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed GPT Transformer Layer.
"""
def __init__(self,
config,
mp_group=None,
quantize_scales=None,
quantize_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super().__init__(config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
class DeepSpeedMegatronGPTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed Megatron GPT Transformer Layer.
"""
def __init__(self,
config,
mp_group=None,
quantize_scales=None,
quantize_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super().__init__(config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
class DeepSpeedOPTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed OPT Transformer Layer.
"""
def __init__(self,
config,
mp_group=None,
quantize_scales=None,
quantize_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super().__init__(config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
import torch.nn as nn
from deepspeed import comm as dist
from deepspeed.utils.logging import log_dist
from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
inference_cuda_module = None
class DeepSpeedTransformerInference(nn.Module):
"""Initialize the DeepSpeed Transformer Layer.
Arguments:
layer_id: The layer index starting from 0, e.g. if model has 24 transformer layers,
layer_id will be 0,1,2...23 when each layer object is instantiated
config: An object of DeepSpeedInferenceConfig
mp_group: Model parallelism group initialized on the modeling side.
quantize_scales: This argument groups all the layers' scales used for quantization
quantize_groups: Number of groups used for quantizing the model
merge_count: Shows the number of model-parallel checkpoints merged before running inference.
We use this argument to control the quantization scale for the model parameters if a bigger
quantize-grouping than 1 is used.
mlp_extra_grouping: This flag is used to show a 2x higher number of groups used for the MLP part
of a Transformer layer. We use this feature for quantization to reduce the convergence impact
for specific downstream tasks.
"""
layer_id = 0
def __init__(self,
config,
mp_group=None,
quantize_scales=None,
quantize_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super(DeepSpeedTransformerInference, self).__init__()
self.config = config
self.config.layer_id = DeepSpeedTransformerInference.layer_id
DeepSpeedTransformerInference.layer_id += 1
data_type = torch.half if config.fp16 else torch.float
global inference_cuda_module
if inference_cuda_module is None:
builder = InferenceBuilder()
inference_cuda_module = builder.load()
if DeepSpeedTransformerInference.layer_id == 1:
log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
if self.config.bigscience_bloom:
self.attention = BloomSelfAttention(self.config,
mp_group,
quantize_scales,
quantize_groups,
merge_count)
else:
self.attention = DeepSpeedSelfAttention(self.config,
mp_group,
quantize_scales,
quantize_groups,
merge_count)
self.mlp = DeepSpeedMLP(self.config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
device = get_accelerator().current_device_name(
) # if config.bigscience_bloom else 'cpu'
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.layer_past = None
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \
inference_cuda_module.allocate_workspace_fp16
@classmethod
def reset_cache(cls):
if inference_cuda_module is not None:
inference_cuda_module.reset_cache()
def forward(
self,
input=None,
input_mask=None,
attention_mask=None,
attn_mask=None,
head_mask=None,
layer_past=None,
get_key_value=False,
get_present=False,
encoder_output=None,
enc_dec_attn_mask=None,
x=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
alibi=None,
output_attentions=False,
# TODO(arashb): 'layer_head_mask' and 'past_key_value' are only added to satisfy the OPT models API.
# This needs to be redesigned later!
layer_head_mask=None,
past_key_value=None):
if x is not None:
input = x
input_mask = (input_mask if attn_mask is None else
attn_mask) if attention_mask is None else attention_mask
# Allocate memory only on first layer forward
if self.config.layer_id == 0:
self.allocate_workspace(self.config.hidden_size,
self.config.heads,
input.size()[1],
input.size()[0],
DeepSpeedTransformerInference.layer_id,
self.config.mp_size,
self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0,
self.config.max_out_tokens)
get_present = (get_present or get_key_value or use_cache)
input_mask = input_mask if attention_mask is None else attention_mask
# We set the prev key/value to None when there is a prompt
if input.shape[1] > 1:
self.layer_past = None
layer_past = layer_past if layer_past is not None else self.layer_past
head_mask = layer_head_mask if layer_head_mask is not None else head_mask
attn_mask = None
if isinstance(input, tuple):
attn_mask = input[1]
input = input[0]
input_type = input.dtype
if (self.config.fp16 or self.config.q_int8) \
and input.dtype == torch.float:
input = input.half()
with torch.no_grad():
attention_output, key, value, context_outputtn_ctx, inp_norm = \
self.attention(input,
input_mask,
head_mask,
layer_past,
get_present,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
self.norm_w,
self.norm_b,
alibi)
presents = (key, value)
self.layer_past = presents if layer_past is None else None
output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)
if not self.config.pre_layer_norm:
output = inference_cuda_module.layer_norm(output,
self.norm_w,
self.norm_b,
self.config.epsilon)
output = output.to(input_type)
if get_present:
output = (output, presents)
if self.config.return_single_tuple:
return (output, )
elif self.config.return_tuple:
return output if type(output) is tuple else (output, attn_mask)
else:
return output
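As a rough orientation only: a single layer can be constructed directly along the lines below, assuming DeepSpeed's inference kernels can be built on the machine; the config values are placeholders that mirror the fields populated by BaseTransformerContainer.create_ds_model_config() further down:

from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig

cfg = DeepSpeedInferenceConfig(hidden_size=1024, heads=16, fp16=True, mp_size=1)
layer = DeepSpeedTransformerInference(cfg)  # __init__ builds/loads the inference op module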
+'''Copyright The Microsoft DeepSpeed Team'''
-from .replace_module import replace_transformer_layer, revert_transformer_layer
+from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
 from .module_quantize import quantize_transformer_layer
-from .replace_policy import DSPolicy, HFBertLayerPolicy
+from .replace_policy import HFBertLayerPolicy
+from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize
+from .policy import DSPolicy
'''Copyright The Microsoft DeepSpeed Team'''
# Automatic Tensor Parallelism
import re
from torch import nn
from .replace_policy import replace_policies
class AutoTP():
def in_module_list(module, module_list):
for item in module_list:
if type(item).__name__ == type(module).__name__:
return True
return False
def get_module_list(model):
mlist = []
for child in model.children():
if isinstance(child, nn.ModuleList):
for module in child.children():
if not mlist:
mlist = [module]
elif not AutoTP.in_module_list(module, mlist):
mlist = mlist + [module]
else:
mlist = mlist + AutoTP.get_module_list(child)
return mlist
def supported(model):
unsupported = [
'bloom',
'codegen',
'deberta',
'flaubert',
'fsmt',
'gpt2',
'led',
'longformer',
'xlm',
'xlnet'
]
model = str(model)
key = re.search(r": (.*?)Model", model)
if key is None:
key = re.search(r": (.*?)Stack", model)
if key is None:
key = re.match(r"(.*?)Model", model)
assert key is not None, "Not able to determine model policy automatically. Please provide policy."
if key.group(1).lower() in unsupported:
return False
return True
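# Illustration (model names are examples only): str() of a Hugging Face
# BertForMaskedLM typically contains "(bert): BertModel(", so the first regex
# captures "Bert" and supported() returns True, while a GPT2LMHeadModel yields
# "GPT2" -> "gpt2", which is in the unsupported list, so it returns False.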
def get_layers(parent, module):
layer_list = []
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + [parent + "." + key]
elif isinstance(submodule,
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
return layer_list
def update_policy_list(policy_list, new_module, new_gems):
if len(policy_list):
for i, policy in enumerate(policy_list):
# if module already exists in policy, combine gems and remove duplicates
if policy[0] == type(new_module):
new_gems = set(new_gems + policy[1])
policy_list[i] = tuple([type(new_module), new_gems])
return policy_list
policy_list.append(tuple([type(new_module), new_gems]))
return policy_list
def kernel_supported(module_list):
policy = []
for plcy in replace_policies:
# instantiate a throw-away policy in order to populate the _orig_layer_class
_ = plcy(None)
if isinstance(plcy._orig_layer_class, list):
for orig_layer_class in plcy._orig_layer_class:
policy.append(orig_layer_class)
elif plcy._orig_layer_class is not None:
policy.append(plcy._orig_layer_class)
for child in module_list:
if child.__class__ in policy:
return True
return False
def tp_parser(model):
policy_list = []
module_list = []
layer_list = []
gem_list = []
module_list = AutoTP.get_module_list(model)
assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
for module in module_list:
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + ["." + key]
elif isinstance(
submodule,
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
for i, layer in enumerate(layer_list):
if layer == 'ln':
if layer_list[i - 1] != 'ln':
gem_list = gem_list + [layer_list[i - 1]]
elif 'out_proj' in layer:
gem_list = gem_list + [layer]
layer_list = []
if gem_list != []:
gem_list = list(set(gem_list))
policy_list = AutoTP.update_policy_list(policy_list, module, gem_list)
gem_list = []
assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy."
return policy_list
'''Copyright The Microsoft DeepSpeed Team'''
from .bert import DS_BERTContainer, HFBertLayerPolicy
from .bloom import DS_BloomContainer, BLOOMLayerPolicy, supported_models
from .distil_bert import DS_DistilBERTContainer, HFDistilBertLayerPolicy
from .gpt2 import DS_GPT2Container, HFGPT2LayerPolicy
from .gptj import DS_GPTJContainer, HFGPTJLayerPolicy
from .gptneo import DS_GPTNEOContainer, HFGPTNEOLayerPolicy
from .gptneox import DS_GPTNEOXContainer, GPTNEOXLayerPolicy
from .megatron_gpt import DS_MegatronGPTContainer, MegatronLayerPolicy
from .megatron_gpt_moe import DS_MegatronGPTMoEContainer, MegatronMoELayerPolicy
from .opt import DS_OPTContainer, HFOPTLayerPolicy
from .clip import DS_CLIPContainer, HFCLIPLayerPolicy
from .unet import UNetPolicy
from .vae import VAEPolicy
'''Copyright The Microsoft DeepSpeed Team'''
# Create a container object to save model-specific tensors using the policy file above.
from abc import ABC
import torch
from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
from deepspeed.accelerator import get_accelerator
class BaseConvolutionContainer(ABC):
# not implemented
def __init__(self):
pass
class BaseTransformerContainer(ABC):
def __init__(self, policy, config, model_config, layer_id, child):
self.policy = policy
self.config = config
self.model_config = model_config
self.layer_id = layer_id
self.child = child
self.megatron_v2 = self.policy.is_megatron_v2
self.scale_attention = self.policy.scale_attention
self.ckpt_load_enabled = False
# configuration for models. todo: can this be moved to a pydantic model config?
self.hidden_size = None
self.num_attention_heads = None
self.mp_size = self.config.tensor_parallel.tp_size
self.pre_layer_norm = self.policy.pre_attn_norm
self.fp16 = False
self.attn_linear_layer = self.policy.linear_layer
self.mlp_linear_layer = self.policy.linear_layer
self.layer_norm_eps = self.model_config.layer_norm_eps if \
hasattr(self.model_config, 'layer_norm_eps') else (self.model_config.layer_norm_epsilon if \
hasattr(self.model_config, 'layer_norm_epsilon') else self.model_config.layernorm_epsilon if \
hasattr(self.model_config, 'layernorm_epsilon') else 1.0e-12)
self.return_tuple = self.config.return_tuple
self.triangular_masking = True
self.local_attention = ((self.model_config.attention_layers[self.layer_id]
== "local") if hasattr(self.model_config,
'attention_layers') else False)
self.window_size = getattr(self.model_config, "window_size", 1)
self.mlp_act_func_type = self.policy.mlp_act_func_type
self.training_mp_size = self.config.training_mp_size
self.bigscience_bloom = False
self.max_out_tokens = self.config.max_out_tokens
self.scale_attn_by_inverse_layer_idx = getattr(
self.config,
"scale_attn_by_inverse_layer_idx",
False)
self.use_mup = self.policy.use_mup
self.return_single_tuple = False
self.rotary_dim = self.model_config.rotary_dim if hasattr(self.model_config, 'rotary_dim') \
else self.child.attention.rotary_ndims if \
hasattr(self.child, 'attention') and hasattr(self.child.attention,'rotary_ndims') else -1
self.mlp_after_attn = (self.rotary_dim is None or self.rotary_dim < 0)
# Attention tensors
self.qkvw = None
self.qkvb = None
self.dense_w = None
self.dense_b = None
# MLP tensors
self._h4h_w = None
self._h4h_b = None
self._4hh_w = None
self._4hh_b = None
# LayerNorm tensors
self.attn_nw = None
self.attn_nb = None
self.input_nw = None
self.input_nb = None
def create_ds_model_config(self):
self.set_hidden_heads(*self.policy.get_hidden_heads())
assert self.num_attention_heads % self.mp_size == 0,\
"To run the model parallel across the GPUs, the attention_heads require to be divisible by the world_size!" +\
"This is because the attention computation is partitioned evenly among the parallel GPUs."
self.ds_model_config = DeepSpeedInferenceConfig(
hidden_size=self.hidden_size,
heads=self.num_attention_heads,
layer_norm_eps=self.layer_norm_eps,
fp16=self.fp16,
pre_layer_norm=self.pre_layer_norm,
mp_size=self.mp_size,
q_int8=self.quantize,
return_tuple=self.return_tuple,
triangular_masking=self.triangular_masking,
local_attention=self.local_attention,
window_size=self.window_size,
rotary_dim=self.rotary_dim,
mlp_after_attn=self.mlp_after_attn,
mlp_act_func_type=self.mlp_act_func_type,
training_mp_size=self.training_mp_size,
bigscience_bloom=self.bigscience_bloom,
max_out_tokens=self.max_out_tokens,
scale_attn_by_inverse_layer_idx=self.scale_attn_by_inverse_layer_idx,
use_mup=self.use_mup,
return_single_tuple=self.return_single_tuple,
)
return self.ds_model_config
def initialize_tensors(self):
# Set the tensors from policy (user module) to container (DS module)
self.set_attention(*self.policy.attention())
self.set_mlp(*self.policy.mlp())
self.set_layernorm(*self.policy.layernorm())
def convert_to_required_dtype(self, dtype):
# Note: converting tensors to fp16 requires that we do it in-place using self.__dict__ and not make a list/dict copy
if dtype == torch.half:
for k, v in self.__dict__.items():
# The list comprehension is used for MoE tensor lists
if isinstance(v, list) and all((isinstance(tensor, torch.Tensor) \
or isinstance(tensor, torch.nn.Parameter)) for tensor in v):
self.__dict__[k] = [moe_tensor.half() for moe_tensor in v]
if isinstance(v, torch.Tensor) or isinstance(v, torch.nn.Parameter):
self.__dict__[k] = v.half()
def set_dtype(self, fp16=False):
self.fp16 = fp16
def set_moe(self, moe=False):
self.moe = moe
def set_tensor_parallel_config(self, mp_size, mp_group):
self.mp_size = mp_size
self.mp_group = mp_group
def set_quantization_config(self, quantize, quantizer):
self.quantize = quantize
self.quantizer = quantizer
def set_hidden_heads(self, hidden_size, num_attention_heads):
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
def set_attention(self, qkvw, qkvb, dense_w, dense_b):
self.qkvw = qkvw
self.qkvb = qkvb
self.dense_w = dense_w
self.dense_b = dense_b
def set_mlp(self, _h4h_w, _h4h_b, _4hh_w, _4hh_b):
self._h4h_w = _h4h_w
self._h4h_b = _h4h_b
self._4hh_w = _4hh_w
self._4hh_b = _4hh_b
def set_layernorm(self, attn_nw, attn_nb, input_nw, input_nb):
self.attn_nw = attn_nw
self.attn_nb = attn_nb
self.input_nw = input_nw
self.input_nb = input_nb
def apply_weight_quantization(self):
# quantize attention weights
self.attention_quantization()
# quantize mlp weights
self.mlp_quantization()
def attention_quantization(self):
self.module.attention.attn_qkvw = self.quantizer.quantize(
self.module.attention.attn_qkvw)
self.module.attention.attn_ow = self.quantizer.quantize(
self.module.attention.attn_ow)
def mlp_quantization(self):
self.module.mlp.inter_w = self.quantizer.quantize(self.module.mlp.inter_w)
self.module.mlp.output_w = self.quantizer.quantize(self.module.mlp.output_w)
def apply_tensor_parallelism(self, mp_replace):
# setup the new Attention module
self.attention_qkv_mp(mp_replace)
self.attention_o_mp(mp_replace)
# setup the new MLP module
self.mlp_inter_mp(mp_replace)
self.mlp_output_mp(mp_replace)
# Apply weight quantization
self.apply_weight_quantization()
def attention_qkv_mp(self, mp_replace):
self.module.attention.attn_qkvw = mp_replace.qkv_copy(
self.module.attention.attn_qkvw,
self.qkvw)
self.module.attention.attn_qkvb = mp_replace.qkv_copy(
self.module.attention.attn_qkvb,
self.qkvb)
def attention_o_mp(self, mp_replace):
self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow,
self.dense_w)
self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob,
self.dense_b)
def mlp_inter_mp(self, mp_replace):
self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w)
self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b)
def mlp_output_mp(self, mp_replace):
self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w)
self.module.mlp.output_b = mp_replace.copy(self.module.mlp.output_b, self._4hh_b)
def copy_data_to_new_module(self):
if self.attn_nw is None:
self.module.mlp.attn_nw = self.attn_nw
self.module.mlp.attn_nb = self.attn_nb
else:
self.module.mlp.attn_nw.data.copy_(
self.attn_nw.to(get_accelerator().current_device_name()))
self.module.mlp.attn_nb.data.copy_(
self.attn_nb.to(get_accelerator().current_device_name()))
self.module.norm_w.data.copy_(
self.input_nw.to(get_accelerator().current_device_name()))
self.module.norm_b.data.copy_(
self.input_nb.to(get_accelerator().current_device_name()))
def transpose(self):
self.transpose_attention()
self.transpose_mlp()
def transpose_attention(self):
if self.attn_linear_layer:
self.qkvw = self.transpose_impl(self.qkvw.data)
self.dense_w = self.transpose_impl(self.dense_w.data)
def transpose_mlp(self):
if self.mlp_linear_layer:
self._h4h_w = self.transpose_impl(self._h4h_w.data)
self._4hh_w = self.transpose_impl(self._4hh_w.data)
def transpose_impl(self, data):
data = data.contiguous()
data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
data = data.reshape(data.shape[-1], data.shape[-2])
data.to(get_accelerator().current_device_name())
return data
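# Worked example: for a contiguous 2x3 weight [[1, 2, 3], [4, 5, 6]] the
# in-place copy writes the transposed values [1, 4, 2, 5, 3, 6] into the same
# storage, and the final reshape returns the 3x2 tensor [[1, 4], [2, 5], [3, 6]],
# i.e. the transposed weight.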