Commit 7d1a83a9 authored by aiss

push DeepSpeed 0.6.3 ROCm version

parent ab5534fc
# Copyright 2020 The Microsoft DeepSpeed Team
"""
DeepSpeed launcher, this is similar to torch.distributed.launch but supports
additional features such as arbitrary gpu exclusion.
deepspeed.launcher.launch is intended to be run on a single worker node and
will spawn several worker sub-processes depending on how many devices/ranks
......@@ -21,6 +21,8 @@ from argparse import ArgumentParser, REMAINDER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger
PID_FILE_BASEPATH = "/tmp"
def parse_args():
parser = ArgumentParser(description="DeepSpeed distributed training launch"
......@@ -51,6 +53,27 @@ def parse_args():
type=str,
help="world info base64 encoded dictionary")
parser.add_argument("--module",
action="store_true",
help="Change each process to interpret the launch "
"script as a Python module, executing with the same "
"behavior as 'python -m'.")
parser.add_argument("--no_python",
action="store_true",
help="Skip prepending the training script with "
"'python' - just execute it directly.")
parser.add_argument("--no_local_rank",
action="store_true",
help="Do not pass local_rank as an argument when calling "
"the user's training script.")
parser.add_argument("--save_pid",
type=int,
default=0,
help="main launching process pid, for internal pid tracking")
# positional
parser.add_argument("training_script",
type=str,
......@@ -70,23 +93,21 @@ def main():
for k in current_env.keys():
if "NCCL" in k:
logger.info("%s %s %s", args.node_rank, k, current_env[k])
logger.info(f"{args.node_rank} {k}={current_env[k]}")
world_info = None
assert args.world_info != "None", "must provide world info dict"
if args.world_info == "None":
raise ValueError("world_info can not be None")
world_info = base64.urlsafe_b64decode(args.world_info)
world_info = json.loads(world_info)
logger.info("WORLD INFO DICT: {}".format(world_info))
logger.info(f"WORLD INFO DICT: {world_info}")
node_list = list(world_info.keys())
args.nnodes = len(node_list)
local_node = node_list[args.node_rank]
local_gpu_ids = world_info[local_node]
num_local_procs = len(local_gpu_ids)
logger.info(
"nnodes={}, num_local_procs={}, node_rank={}".format(args.nnodes,
num_local_procs,
args.node_rank),
f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}"
)
global_rank_mapping = defaultdict(list)
......@@ -98,19 +119,32 @@ def main():
for gid in gids:
global_rank_mapping[node_id].append(curr_global_rank)
curr_global_rank += 1
logger.info("global_rank_mapping={}".format(global_rank_mapping))
logger.info("dist_world_size={}".format(dist_world_size))
logger.info(f"global_rank_mapping={global_rank_mapping}")
logger.info(f"dist_world_size={dist_world_size}")
current_env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, local_gpu_ids))
logger.info("Setting CUDA_VISIBLE_DEVICES={}".format(
current_env["CUDA_VISIBLE_DEVICES"]))
exclusion_counts_per_node = None
logger.info(f"Setting CUDA_VISIBLE_DEVICES={current_env['CUDA_VISIBLE_DEVICES']}")
# set PyTorch distributed related environmental variables
current_env["MASTER_ADDR"] = args.master_addr
current_env["MASTER_PORT"] = str(args.master_port)
current_env["WORLD_SIZE"] = str(dist_world_size)
current_env["CROSS_RANK"] = str(args.node_rank)
current_env["CROSS_SIZE"] = str(args.nnodes)
current_env["LOCAL_SIZE"] = str(num_local_procs)
if args.save_pid:
print(f"launcher pid: {os.getpid()}")
pid_file = None
if args.save_pid:
launcher_pid = os.getpid()
pid_file = os.path.join(PID_FILE_BASEPATH, f"{args.save_pid}.deepspeed")
assert not os.path.isfile(pid_file), "pid file exists but shouldn't"
with open(pid_file, 'w') as fd:
fd.write(f"{launcher_pid}")
processes = []
cmd = []
for local_rank in range(0, num_local_procs):
# each process's rank
dist_rank = global_rank_mapping[local_node][local_rank]
......@@ -118,36 +152,48 @@ def main():
current_env["LOCAL_RANK"] = str(local_rank)
# spawn the processes
cmd = []
if not args.no_python:
cmd = [sys.executable, "-u"]
if args.module:
cmd.append("-m")
else:
if args.module:
raise ValueError("Don't use both the '--no_python' flag"
" and the '--module' flag at the same time.")
cmd.append(args.training_script)
# A user may not want to pass local_rank as a keyword arg so we make this optional.
if not args.no_local_rank:
cmd.append(f"--local_rank={local_rank}")
cmd += args.training_script_args
sig_names = {2: "SIGINT", 15: "SIGTERM"}
last_return_code = None
def sigkill_handler(signum, frame):
for process in processes:
print(f"Killing subprocess {process.pid}")
logger.info(f"Killing subprocess {process.pid}")
try:
process.kill()
except Exception:
pass
if last_return_code is not None:
logger.error(f"{cmd} exits with return code = {last_return_code}")
sys.exit(last_return_code)
if signum in sig_names:
print(f"Main process received {sig_names[signum]}, exiting")
logger.info(f"Main process received {sig_names[signum]}, exiting")
if args.save_pid:
if os.path.isfile(pid_file):
os.remove(pid_file)
sys.exit(1)
# pass SIGINT/SIGTERM to children if the parent is being terminated
signal.signal(signal.SIGINT, sigkill_handler)
signal.signal(signal.SIGTERM, sigkill_handler)
process = subprocess.Popen(cmd, env=current_env)
processes.append(process)
alive_processes = set(processes)
while len(alive_processes):
finished_processes = []
......@@ -161,6 +207,7 @@ def main():
sigkill_handler(signal.SIGTERM, None) # not coming back
else:
# exited cleanly
logger.info(f"Process {process.pid} exits successfully.")
finished_processes.append(process)
alive_processes = set(alive_processes) - set(finished_processes)
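For illustration, with default flags the per-rank command assembled in the loop above reduces to python -u <training_script> --local_rank=<local_rank> <training_script_args>, e.g. (hypothetical script and arguments): python -u train.py --local_rank=0 --deepspeed_config ds_config.json. Passing --module inserts -m so the script is resolved as a package module, --no_python drops the leading interpreter entirely, and --no_local_rank suppresses the --local_rank argument.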
......
......@@ -3,6 +3,7 @@ import sys
import shutil
import subprocess
import warnings
from shlex import quote
from abc import ABC, abstractmethod
from ..utils import logger
......@@ -12,6 +13,7 @@ from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE
class MultiNodeRunner(ABC):
def __init__(self, args, world_info_base64):
self.args = args
self.validate_args()
self.user_arguments = self.parse_user_args()
self.user_script = args.user_script
self.world_info_base64 = world_info_base64
......@@ -19,11 +21,11 @@ class MultiNodeRunner(ABC):
@abstractmethod
def backend_exists(self):
"""Return whether the corresponding backend exists"""
@abstractmethod
def get_cmd(self, environment, active_resources):
"""Return the command to execute on node"""
def add_export(self, key, var):
self.exports[key.strip()] = var.strip()
......@@ -31,6 +33,14 @@ class MultiNodeRunner(ABC):
def parse_user_args(self):
return self.args.user_args
@property
def name(self):
"""Return the name of the backend"""
return self.__class__.__name__
def validate_args(self):
"""Validate self.args"""
class PDSHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64):
......@@ -39,9 +49,13 @@ class PDSHRunner(MultiNodeRunner):
def backend_exists(self):
return shutil.which('pdsh')
@property
def name(self):
return "pdsh"
def parse_user_args(self):
return list(
map(lambda x: x if x.startswith("-") else "'{}'".format(x),
map(lambda x: x if x.startswith("-") else f"'{x}'",
self.args.user_args))
def get_cmd(self, environment, active_resources):
......@@ -58,19 +72,28 @@ class PDSHRunner(MultiNodeRunner):
for key, val in self.exports.items():
exports += "export {}={}; ".format(key, val)
# https://linux.die.net/man/1/pdsh
# %n will be replaced by pdsh command
deepspeed_launch = [
exports,
"cd {};".format(os.path.abspath('.')),
f"cd {os.path.abspath('.')};",
sys.executable,
"-u",
"-m",
"deepspeed.launcher.launch",
f'--world_info={self.world_info_base64}',
"--node_rank=%n",
"--master_addr={}".format(self.args.master_addr),
"--master_port={}".format(self.args.master_port)
f"--master_addr={self.args.master_addr}",
f"--master_port={self.args.master_port}"
]
if self.args.no_python:
deepspeed_launch.append("--no_python")
if self.args.module:
deepspeed_launch.append("--module")
if self.args.no_local_rank:
deepspeed_launch.append("--no_local_rank")
if self.args.save_pid:
deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
return pdsh_cmd_args + deepspeed_launch + [self.user_script
] + self.user_arguments
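As a rough sketch, the list returned above serialises to a pdsh invocation of the following shape (hostnames, export, user script and arguments are hypothetical; the pdsh options themselves come from the part of get_cmd elided by this hunk):

pdsh -f 1024 -w worker-0,worker-1 export NCCL_DEBUG=INFO; cd /job; python -u -m deepspeed.launcher.launch --world_info=<base64> --node_rank=%n --master_addr=worker-0 --master_port=29500 train.py --epochs '10'

pdsh substitutes %n per target host, which is why node_rank can stay symbolic here, and non-flag user arguments arrive single-quoted from parse_user_args.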
......@@ -85,10 +108,21 @@ class OpenMPIRunner(MultiNodeRunner):
#TODO: if IB is available we should suggest mvapich
return shutil.which('ompi_info')
@property
def name(self):
return "openmpi"
def validate_args(self):
super().validate_args()
#TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "":
raise ValueError(
f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError(
f"{self.name} backend does not support limiting num nodes/gpus")
def get_cmd(self, environment, active_resources):
total_process_count = sum(self.resource_pool.values())
mpirun_cmd = [
......@@ -107,9 +141,13 @@ class OpenMPIRunner(MultiNodeRunner):
export_cmd = []
for k, v in self.exports.items():
export_cmd += ['-x', "{}={}".format(k, v)]
python_exec = []
if not self.args.no_python:
python_exec = [sys.executable, "-u"]
if self.args.module:
python_exec.append("-m")
return mpirun_cmd + export_cmd + python_exec + [self.user_script
] + self.user_arguments
......@@ -156,14 +194,26 @@ class MVAPICHRunner(MultiNodeRunner):
)
return exists
@property
def name(self):
return "mvapich"
def validate_args(self):
super().validate_args()
#TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "":
raise ValueError(
f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError(
f"{self.name} backend does not support limiting num nodes/gpus")
def get_cmd(self, environment, active_resources):
devices_per_node = self.resource_pool.values()
total_process_count = sum(devices_per_node)
process_per_node = list(devices_per_node)[0]
if not all([n == process_per_node for n in devices_per_node]):
raise ValueError("mvapich requires same number of devices per node")
with open(MVAPICH_TMP_HOSTFILE, 'w') as fd:
for host in self.resource_pool.keys():
......@@ -181,9 +231,13 @@ class MVAPICHRunner(MultiNodeRunner):
export_cmd = []
for k, v in self.exports.items():
export_cmd += ['-env', "{}={}".format(k, v)]
python_exec = []
if not self.args.no_python:
python_exec = [sys.executable, "-u"]
if self.args.module:
python_exec.append("-m")
return mpirun_cmd + export_cmd + python_exec + [self.user_script
] + self.user_arguments
......@@ -2,14 +2,13 @@
"""
DeepSpeed runner is the main front-end to launching multi-worker
training jobs with DeepSpeed. By default this uses pdsh to parallel
ssh into multiple worker nodes and launch all the necessary processes
per rank for training.
"""
import os
import sys
import json
import shutil
import base64
import argparse
import subprocess
......@@ -23,8 +22,10 @@ from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger
from ..autotuning import Autotuner
DLTS_HOSTFILE = "/job/hostfile"
EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", 'UCX']
EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", "UCX"]
DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env"
DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.']
PDSH_MAX_FAN_OUT = 1024
......@@ -95,7 +96,7 @@ def parse_args(args=None):
parser.add_argument("--launcher",
default=PDSH_LAUNCHER,
type=str,
help="(optional) choose launcher backend for multi-node"
help="(optional) choose launcher backend for multi-node "
"training. Options currently include PDSH, OpenMPI, MVAPICH.")
parser.add_argument("--launcher_args",
......@@ -104,6 +105,43 @@ def parse_args(args=None):
help="(optional) pass launcher specific arguments as a "
"single quoted argument.")
parser.add_argument("--module",
action="store_true",
help="Change each process to interpret the launch "
"script as a Python module, executing with the same "
"behavior as 'python -m'.")
parser.add_argument("--no_python",
action="store_true",
help="Skip prepending the training script with "
"'python' - just execute it directly.")
parser.add_argument("--no_local_rank",
action="store_true",
help="Do not pass local_rank as an argument when calling "
"the user's training script.")
parser.add_argument("--force_multi",
action="store_true",
help="Force multi-node launcher mode, helps in cases where user "
"wants to launch on single remote node.")
parser.add_argument(
"--save_pid",
action="store_true",
help="Save file containing launcher process id (pid) at /tmp/<main-pid>.ds, "
"where <main-pid> is the pid of the first process that invoked `deepspeed`. "
"Useful when launching deepspeed processes programmatically.")
parser.add_argument(
"--autotuning",
default="",
choices=["tune",
"run"],
type=str,
help="Run DeepSpeed autotuner to discover optimal configuration parameters "
"before running job.")
parser.add_argument("user_script",
type=str,
help="User script to launch, followed by any required "
......@@ -137,12 +175,22 @@ def fetch_hostfile(hostfile_path):
if hostname in resource_pool:
logger.error("Hostfile contains duplicate hosts, unable to "
"proceed with training.")
raise ValueError("host {} is already defined".format(hostname))
raise ValueError(f"host {hostname} is already defined")
resource_pool[hostname] = slot_count
return resource_pool
def _stable_remove_duplicates(data):
# Create a new list in the same order as original but with duplicates
# removed, should never be more than ~16 elements so simple is best
new_list = []
for x in data:
if x not in new_list:
new_list.append(x)
return new_list
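For example, _stable_remove_duplicates([0, 1, 1, 3, 0]) returns [0, 1, 3]: unlike the list(set(...)) call it replaces further down, the first-seen order of the slots is preserved.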
def parse_resource_filter(host_info, include_str="", exclude_str=""):
'''Parse an inclusion or exclusion string and filter a hostfile dictionary.
......@@ -187,27 +235,25 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""):
# sanity checks
if hostname not in host_info:
raise ValueError("Hostname '{}' not found in hostfile".format(hostname))
for s in slots:
if s not in host_info[hostname]:
raise ValueError("No slot '{}' specified on host '{}'".format(
s,
hostname))
raise ValueError(f"Hostname '{hostname}' not found in hostfile")
for slot in slots:
if slot not in host_info[hostname]:
raise ValueError(f"No slot '{slot}' specified on host '{hostname}'")
# If include string, build the list from here
if include_str:
filtered_hosts[hostname] = slots
elif exclude_str:
for slot in slots:
logger.info(f'removing {slot} from {hostname}')
filtered_hosts[hostname].remove(slot)
# User just specified the whole node
else:
hostname = node_config
# sanity check hostname
if hostname not in host_info:
raise ValueError("Hostname '{}' not found in hostfile".format(hostname))
raise ValueError(f"Hostname '{hostname}' not found in hostfile")
if include_str:
filtered_hosts[hostname] = host_info[hostname]
......@@ -218,7 +264,7 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""):
del_keys = []
for hostname in filtered_hosts:
# Remove duplicates
filtered_hosts[hostname] = _stable_remove_duplicates(filtered_hosts[hostname])
# Remove empty hosts
if len(filtered_hosts[hostname]) == 0:
del_keys.append(hostname)
......@@ -251,15 +297,43 @@ def encode_world_info(world_info):
return world_info_base64
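The encoder body is collapsed by this hunk; assuming it mirrors the decode path in launch.py (base64.urlsafe_b64decode followed by json.loads), the round trip looks like this minimal sketch with a hypothetical resource dict:

import base64
import json

world_info = {"worker-0": [0, 1], "worker-1": [0, 1]}  # hostname -> GPU slots
world_info_base64 = base64.urlsafe_b64encode(json.dumps(world_info).encode()).decode()
# launch.py reverses this when it parses --world_info
assert json.loads(base64.urlsafe_b64decode(world_info_base64)) == world_info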
def run_autotuning(args, active_resources):
tuner = Autotuner(args, active_resources)
logger.info("[Start] Running autotuning")
tuner.tune()
tuner.print_tuning_results()
logger.info("[End] Running autotuning")
if args.autotuning == "run":
tuner.run_after_tuning()
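For reference, the --autotuning flag added to parse_args above drives this path from the command line; a hypothetical invocation (script, flag values and config name assumed) is: deepspeed --autotuning run --num_gpus 8 train.py --deepspeed_config ds_config.json. With "tune" the tuner only searches and prints its results; with "run" it additionally launches the job using the discovered configuration via run_after_tuning().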
def main(args=None):
args = parse_args(args)
resource_pool = fetch_hostfile(args.hostfile)
# respect CUDA_VISIBLE_DEVICES for a single node and no explicit resource filters
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if not resource_pool and len(cuda_visible_devices):
detected_str = f"Detected CUDA_VISIBLE_DEVICES={cuda_visible_devices}"
if len(args.include) or len(
args.exclude) or args.num_nodes > 1 or args.num_gpus > 0:
print(
f"{detected_str} but ignoring it because one or several of --include/--exclude/--num_gpus/--num_nodes cl args were used. If you want to use CUDA_VISIBLE_DEVICES don't pass any of these arguments to deepspeed."
)
else:
args.include = f"localhost:{cuda_visible_devices}"
print(f"{detected_str}: setting --include={args.include}")
del os.environ["CUDA_VISIBLE_DEVICES"]
if args.num_nodes >= 0 or args.num_gpus >= 0:
if args.include != "" or args.exclude != "":
raise ValueError("Cannot specify num_nodes/gpus with include/exclude")
multi_node_exec = True
resource_pool = fetch_hostfile(args.hostfile)
if not resource_pool:
resource_pool = {}
device_count = torch.cuda.device_count()
......@@ -275,17 +349,33 @@ def main(args=None):
active_resources = parse_inclusion_exclusion(resource_pool,
args.include,
args.exclude)
env = os.environ.copy()
# validate that passwordless-ssh is working properly with this hostfile
if multi_node_exec:
first_host = list(active_resources.keys())[0]
try:
subprocess.check_call(
f'ssh -o PasswordAuthentication=no {first_host} hostname',
stderr=subprocess.DEVNULL,
stdout=subprocess.DEVNULL,
shell=True)
except subprocess.CalledProcessError:
raise RuntimeError(
f"Using hostfile at {args.hostfile} but host={first_host} was not reachable via ssh. If you are running with a single node please remove {args.hostfile} or setup passwordless ssh."
)
if not args.master_addr:
assert multi_node_exec
first_host = list(active_resources.keys())[0]
hostname_cmd = ["ssh {} hostname -I".format(first_host)]
hostname_cmd = [f"ssh {first_host} hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
args.master_addr = result.decode('utf-8').split()[0]
logger.info("Using IP address of {} for node {}".format(
args.master_addr,
first_host))
logger.info(f"Using IP address of {args.master_addr} for node {first_host}")
if args.autotuning != "":
run_autotuning(args, active_resources)
return
if args.num_nodes > 0:
updated_active_resources = collections.OrderedDict()
......@@ -304,10 +394,7 @@ def main(args=None):
# encode world info as base64 to make it easier to pass via command line
world_info_base64 = encode_world_info(active_resources)
multi_node_exec = args.force_multi or len(active_resources) > 1
if not multi_node_exec:
deepspeed_launch = [
......@@ -315,10 +402,18 @@ def main(args=None):
"-u",
"-m",
"deepspeed.launcher.launch",
"--world_info={}".format(world_info_base64),
"--master_addr={}".format(args.master_addr),
"--master_port={}".format(args.master_port)
f"--world_info={world_info_base64}",
f"--master_addr={args.master_addr}",
f"--master_port={args.master_port}"
]
if args.no_python:
deepspeed_launch.append("--no_python")
if args.module:
deepspeed_launch.append("--module")
if args.no_local_rank:
deepspeed_launch.append("--no_local_rank")
if args.save_pid:
deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
cmd = deepspeed_launch + [args.user_script] + args.user_args
else:
args.launcher = args.launcher.lower()
......@@ -350,13 +445,14 @@ def main(args=None):
if os.path.isfile(environ_file):
with open(environ_file, 'r') as fd:
for var in fd.readlines():
key, val = var.split('=', maxsplit=1)
runner.add_export(key, val)
cmd = runner.get_cmd(env, active_resources)
logger.info("cmd = {}".format(' '.join(cmd)))
logger.info(f"cmd = {' '.join(cmd)}")
result = subprocess.Popen(cmd, env=env)
result.wait()
# In case of failure must propagate the error-condition back to the caller (usually shell). The
......
from .replace_module import replace_transformer_layer, revert_transformer_layer
from .module_quantize import quantize_transformer_layer
from .replace_policy import DSPolicy, HFBertLayerPolicy
File mode changed from 100755 to 100644
import copy
import torch
import deepspeed
def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=False):
""" Quantize bert-style transformer layers with DeepSpeed's transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
megatron (bool): megatron model-parallel implementation (this is supported for inference only)
preln (bool): does the original layer implementation do pre or post layer norm?
Note: For BERT-style models, we inject based on the DeepSpeed-Examples models unless the huggingface flag is set.
Returns:
Updated nn.module with quantized transformer layers
"""
def quantize_weight(weight):
return weight.to(torch.int8)
def megatron_layer_quantize(layer):
layer.attention.query_key_value.weight.data = quantize_weight(
layer.attention.query_key_value.weight.data)
layer.attention.dense.weight.data = quantize_weight(
layer.attention.dense.weight.data)
layer.mlp.dense_h_to_4h.weight.data = quantize_weight(
layer.mlp.dense_h_to_4h.weight.data)
layer.mlp.dense_4h_to_h.weight.data = quantize_weight(
layer.mlp.dense_4h_to_h.weight.data)
def bert_layer_quantize(layer):
layer.attention.self.query.weight.data = quantize_weight(
layer.attention.self.query.weight.data)
layer.attention.self.key.weight.data = quantize_weight(
layer.attention.self.key.weight.data)
layer.attention.self.value.weight.data = quantize_weight(
layer.attention.self.value.weight.data)
layer.attention.output.dense.weight.data = quantize_weight(
layer.attention.output.dense.weight.data)
if preln:
layer.intermediate.dense_act.weight.data = quantize_weight(
layer.intermediate.dense_act.weight.data)
else:
layer.intermediate.dense.weight.data = quantize_weight(
layer.intermediate.dense.weight.data)
layer.output.dense.weight.data = quantize_weight(layer.output.dense.weight.data)
def quantize_fn(child):
if megatron:
# Quantize megatron GPT2 / GPT3 trained model
megatron_layer_quantize(child)
else:
# Quantize either DeepSpeed or HuggingFace trained model
bert_layer_quantize(child)
return child
return quantize_module(model=model,
orig_class=orig_layer_impl,
quantize_fn=quantize_fn)
def quantize_module(model, orig_class, quantize_fn):
policy = {orig_class: quantize_fn}
return _quantize_module(model, policy)
def _quantize_module(model, policies):
for name, child in model.named_children():
if child.__class__ in policies:
orig = repr(child)
setattr(model, name, policies[child.__class__](child))
new = getattr(model, name)
else:
_quantize_module(child, policies)
return model
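A minimal usage sketch, assuming this file sits behind the deepspeed.module_inject __init__ shown above and that a HuggingFace BERT checkpoint is available (the BertLayer import path varies across transformers versions):

from transformers import BertModel
from transformers.models.bert.modeling_bert import BertLayer
from deepspeed.module_inject import quantize_transformer_layer

model = BertModel.from_pretrained("bert-base-uncased")
# every BertLayer child found by _quantize_module gets its attention and feed-forward weights cast to int8
model = quantize_transformer_layer(orig_layer_impl=BertLayer, model=model, preln=False)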