Commit 7bc5a8e3 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents e6748d82 0f785cb1
import time
from typing import List, Dict, Type
from abc import ABC, abstractmethod
# Flag recording whether pynvml is unavailable; checked before any NVML query.
NOT_NVML = False
try:
    from pynvml import *
except ImportError:
    # pynvml is an optional dependency: degrade gracefully when it is missing.
    # (Was a bare `except:`, which would also hide unrelated errors raised
    # while importing pynvml.)
    NOT_NVML = True
import torch
from torch.fx.node import Node
from colossalai.utils.cuda import get_current_device
from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
from .region import Region
from .util import NodeInfo, NvDevicePower
def benchmark_func(func, number=1, repeat=1, warmup=3):
    """
    Benchmark data transfer cost.

    Args:
        func (Callable): the zero-argument function to time.
        number (int): calls per timed measurement; the cost is averaged over them.
        repeat (int): how many timed measurements to take.
        warmup (int): untimed calls issued first to exclude one-off startup costs.

    Returns:
        float: mean seconds per single call of ``func``.
    """
    for _ in range(warmup):
        func()
    costs = []
    for _ in range(repeat):
        # drain pending GPU work so it is not attributed to this measurement
        torch.cuda.synchronize()
        # perf_counter is the recommended monotonic clock for intervals
        # (time.time() may jump and has coarser resolution)
        begin = time.perf_counter()
        for _ in range(number):
            func()
        torch.cuda.synchronize()
        costs.append((time.perf_counter() - begin) / number)
    return sum(costs) / len(costs)
class Solver(ABC):
    """
    The parameter offload solver.

    Args:
        region_list (List[Region]): represents the linearized DNN computing graph.
        memory_budget (float): the given memory budget; a non-positive value
            means "use the total memory of the current device".
        error_factor (float): the error factor.
            It is used to reduce the memory budget. Due to some errors in the
            estimation of peak memory and execution time.
    """

    def __init__(self,
                 region_list: List[Region],
                 memory_budget: float = -1.0,
                 error_factor: float = 0.95) -> None:
        self.region_list = region_list
        self.error_factor: float = error_factor
        if memory_budget > 0:
            self.memory_budget = memory_budget * self.error_factor
        else:
            # no budget given: fall back to the device's total memory
            self.memory_budget = torch.cuda.get_device_properties(
                get_current_device()).total_memory * self.error_factor
        self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
        self.comp_power: float = self._extract_computing_power()

    @abstractmethod
    def _call_solver(self):
        raise NotImplementedError

    @abstractmethod
    def _try_to_offload(self, *args):
        raise NotImplementedError

    @abstractmethod
    def _eval_one_choice(self, *args):
        raise NotImplementedError

    def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: float, extra_cost: float):
        """
        Compute the profits of the offload strategies,
        which packages the memory savings information for subsequent comparisons.

        Args:
            total_mem_saving (float): the total memory saving of the offload strategy.
            peak_mem_saving (float): the peak memory saving of the offload strategy.
            extra_cost (float): extra data transfer cost.

        Returns:
            tuple: profit information, the first term represents memory savings per unit of time.
        """
        if extra_cost == 0:
            # means data transfer overhead can be completely overlapped
            return (float('inf'), total_mem_saving, peak_mem_saving)
        return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)

    def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
        """
        Compare the profits of the two offload strategies using the dictionary order algorithm.

        Args:
            profit_a (tuple): the profit of a offload strategy.
            profit_b (tuple): the profit of another offload strategy.

        Returns:
            bool: whether profit_a is greater than profit_b.
        """
        for val1, val2 in zip(profit_a, profit_b):
            if val1 != val2:
                return val1 > val2
        # all compared terms equal: not strictly greater
        return False

    def _update_state(self, best_ts: TrainingSimulator):
        """
        Update the solver state.
        """
        self.best_ts = best_ts
        self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)

    def _update_node_mem_info(self,
                              fwd_mem_info: Dict[Node, float],
                              bwd_mem_info: Dict[Node, float]):
        """
        Update the runtime memory information of the node.

        Args:
            fwd_mem_info (Dict[Node, float]): the runtime memory of each node in forward pass.
            bwd_mem_info (Dict[Node, float]): the runtime memory of each node in backward pass.
        """
        for node, mem in fwd_mem_info.items():
            assert hasattr(node, 'node_info') and isinstance(
                node.node_info, NodeInfo)
            node.node_info.runtime_fwd_mem = mem
        for node, mem in bwd_mem_info.items():
            assert hasattr(node, 'node_info') and isinstance(
                node.node_info, NodeInfo)
            node.node_info.runtime_bwd_mem = mem

    def _extract_computing_power(self):
        """
        Return the FP16 computing performance (FLOPS) of the current NVIDIA GPU.

        Raises:
            TypeError: Unknown NVIDIA GPU device.
        """
        nvmlInit()
        try:
            handle = nvmlDeviceGetHandleByIndex(0)
            device_name = nvmlDeviceGetName(handle)
        finally:
            # was missing entirely: release NVML even if the query fails
            nvmlShutdown()
        if isinstance(device_name, bytes):
            # older pynvml versions return the device name as bytes
            device_name = device_name.decode()
        units = 1e12
        # use the `in` operator rather than calling __contains__ directly
        if "RTX 3080" in device_name:
            return NvDevicePower.RTX3080_FP16 * units
        elif "RTX 3090" in device_name:
            return NvDevicePower.RTX3090_FP16 * units
        elif 'V100' in device_name:
            return NvDevicePower.V100_FP16 * units
        elif "A100" in device_name:
            return NvDevicePower.A100_FP16 * units
        else:
            raise TypeError(f'Unknown NVIDIA GPU device name {device_name}')

    def _profile_bandwidth(self):
        """
        Profile the bidirectional communication bandwidth between CPU and GPU
        using data volumes ranging from 1KB to 1GB.
        """
        print('profiling bandwidth ......')
        link_to_bandwidth = {}
        links = ['h2d', 'd2h']
        for link in links:
            t_size = 1024
            size_to_bandwidth = {}
            # from 1KB to 1GB (21 doubling steps)
            for _ in range(21):
                if link == 'h2d':
                    src_tensor = torch.ones(
                        int(t_size), dtype=torch.int8, pin_memory=True)
                    dst_tensor = torch.ones(
                        (int(t_size)), dtype=torch.int8, device='cuda')
                elif link == 'd2h':
                    src_tensor = torch.ones(
                        int(t_size), dtype=torch.int8, device='cuda')
                    dst_tensor = torch.ones(
                        (int(t_size)), dtype=torch.int8, pin_memory=True)

                def func():
                    dst_tensor.copy_(src_tensor)

                # bandwidth in bytes per second for this transfer size
                size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
                print(f'size: {t_size / 1024 ** 2:.3f} MB, '
                      f'{src_tensor.device.type}-to-{dst_tensor.device.type} '
                      f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s')
                t_size *= 2
            link_to_bandwidth[link] = size_to_bandwidth
        return link_to_bandwidth
class SynGreedySolver(Solver):
    """
    Greedy offload solver assuming synchronous (blocking) parameter uploads:
    an offloaded region is uploaded right before its own backward pass.
    """

    def __init__(self,
                 region_list: List[Region],
                 memory_budget: float = -1.0) -> None:
        super().__init__(region_list, memory_budget)
        # best simulator state found so far; refreshed after every accepted offload
        self.best_ts: SynTrainingSimulator = None
        self._init_state()

    def _init_state(self):
        """
        Initialize the solver state when without offloading.
        """
        ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()
        self._update_state(ts)

    def _call_solver(self):
        """
        Call the solver to search an efficient parameter offloading strategy for the linearized graph.
        The solver adopts greedy algorithm.

        Raises:
            NotImplementedError: Unable to find a solution for the given memory budget.
        """
        print("search offloading strategy ......")
        # keep offloading the most profitable region until the budget is met
        while self.best_ts.peak_mem > self.memory_budget:
            offload_region = None
            best_ts = None
            max_profit = (0,)
            # search which region should be offloaded,
            # the last region does not need to be offloaded.
            for region in self.region_list[:-1]:
                if region.param_size and not region.need_offload:
                    temp_ts, profit = self._try_to_offload(region)
                    if self._compare_profit(profit, max_profit):
                        offload_region = region
                        max_profit = profit
                        best_ts = temp_ts
            if offload_region is not None and best_ts is not None:
                # commit the winning choice permanently
                offload_region.need_offload = True
                offload_region.is_syn = True
                self._update_state(best_ts)
            else:
                raise NotImplementedError(
                    f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
                    f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")

    def _call_solver_l2l(self):
        """
        The layer-wise offload strategy.
        Marks every region except the last for synchronous offload.
        """
        for region in self.region_list[:-1]:
            region.need_offload = True
            region.is_syn = True

    def _try_to_offload(self, offload_region: Region):
        """
        Tentatively mark ``offload_region`` as offloaded, evaluate the choice,
        then restore the region's previous flag.
        """
        # record previous information
        orig_need_offload = offload_region.need_offload
        assert not orig_need_offload
        offload_region.need_offload = True
        ts, profit = self._eval_one_choice(offload_region)
        # restore previous information
        offload_region.need_offload = orig_need_offload
        return ts, profit

    def _eval_one_choice(self, offload_region: Region):
        """
        Evaluate the profit of a strategy choice.

        Args:
            offload_region (Region): the offload region of current choice.

        Returns:
            SynTrainingSimulator: the training simulator corresponding to the current strategy.
            tuple: contains memory saving and cost information of the current strategy.
        """
        ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()
        # factor 2.0 presumably covers both directions of the transfer — TODO confirm
        extra_comm_cost = 2.0 * \
            ts._get_communication_overhead('h2d', offload_region.param_size)
        # the shared region needs to be moved twice
        if offload_region.r_id < offload_region.shared_rid:
            extra_comm_cost *= 2.0
        profit = self._compute_offload_profit(
            ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
        return ts, profit
class AsynGreedySolver(Solver):
    """
    Greedy offload solver that overlaps parameter uploads with computation:
    each offloaded region is prefetched asynchronously while the backward pass
    of a later-executed "host" region is running.
    """

    def __init__(self,
                 region_list: List[Region],
                 memory_budget: float = -1.0,
                 search_window_size: int = 3):
        super().__init__(region_list, memory_budget)
        # number of candidate host regions examined per prefetch placement
        self.search_window_size = search_window_size
        # Records the prefetch execution location of the offloaded region
        self.region_to_region_map = {}
        # best simulator state found so far
        self.best_ts: AsynTrainingSimulator = None
        self._init_state()

    def _init_state(self):
        """
        Initialize the solver state when without offloading.
        """
        ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()
        self._update_state(ts)
        print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB")

    def _call_solver(self):
        """
        Call the solver to search an efficient parameter offloading strategy for the linearized graph.
        The solver adopts greedy algorithm.

        Raises:
            NotImplementedError: Unable to find a solution for the given memory budget.
        """
        print("search for offloading strategy ......")
        # Records the prefetch execution location of the offloaded region
        # (scratch map for the current iteration only)
        region_to_region_map = {}
        while self.best_ts.peak_mem > self.memory_budget:
            region_to_offload = None
            max_offload_profit = (0,)
            best_offl_ts = None
            # search which region should be offloaded,
            # the last region does not need to be offloaded
            for region in self.region_list[:-1]:
                if region.param_size and not region.need_offload:
                    max_prefetch_profit = (0,)
                    best_pref_ts = None
                    # search when to prefetch the region offloaded
                    for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]:
                        # a host can carry at most one prefetch
                        if host_region.bwd_prefetch_region is not None:
                            continue
                        temp_ts, profit = self._try_to_offload(
                            host_region, region)
                        if self._compare_profit(profit, max_prefetch_profit):
                            region_to_region_map[region.r_id] = host_region
                            max_prefetch_profit = profit
                            best_pref_ts = temp_ts
                        if profit[0] == float('inf'):
                            # transfer fully overlapped: no better host can exist
                            break
                    if self._compare_profit(max_prefetch_profit, max_offload_profit):
                        region_to_offload = region
                        max_offload_profit = max_prefetch_profit
                        best_offl_ts = best_pref_ts
            if (region_to_offload is not None) and (best_offl_ts is not None):
                # commit the winning (offload region, host region) pair
                region_to_offload.need_offload = True
                if region_to_region_map[region_to_offload.r_id] == region_to_offload:
                    # hosting itself means a synchronous upload before its backward
                    region_to_offload.is_syn = True
                else:
                    region_to_region_map[region_to_offload.r_id].bwd_prefetch_region = region_to_offload
                self.region_to_region_map[region_to_offload.r_id] = region_to_region_map[region_to_offload.r_id]
                self._update_state(best_offl_ts)
            elif self.region_to_region_map.__len__() > 0:
                # no profitable new offload: degrade an async prefetch instead
                self._repair_strategy()
            else:
                raise NotImplementedError(
                    f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
                    f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
            region_to_region_map.clear()

    def _try_to_offload(self, host_region: Region, offload_region: Region):
        """
        Attempts to offload the region and prefetch it in backward pass.
        """
        # record previous information
        orig_prefetch = host_region.bwd_prefetch_region
        orig_is_syn = offload_region.is_syn
        orig_need_offload = offload_region.need_offload
        if host_region == offload_region:
            offload_region.is_syn = True
        else:
            host_region.bwd_prefetch_region = offload_region
        offload_region.need_offload = True
        ts, profit = self._eval_one_choice()
        # restore previous information
        host_region.bwd_prefetch_region = orig_prefetch
        offload_region.is_syn = orig_is_syn
        offload_region.need_offload = orig_need_offload
        return ts, profit

    def _try_convert_to_syn_upload(self, host_region: Region, offload_region: Region):
        """
        Attempts to convert asynchronous prefetch into synchronous upload operations.
        """
        # record previous information
        orig_prefetch = host_region.bwd_prefetch_region
        orig_is_syn = offload_region.is_syn
        assert orig_prefetch is not None and not orig_is_syn
        host_region.bwd_prefetch_region = None
        offload_region.is_syn = True
        ts, profit = self._eval_one_choice()
        # restore previous information
        host_region.bwd_prefetch_region = orig_prefetch
        offload_region.is_syn = orig_is_syn
        return ts, profit

    def _repair_strategy(self):
        """
        Repair offload strategy.
        It attempts to convert asynchronous prefetch into synchronous upload operations and selects the best one.
        The repair process does not end until peak memory is reduced or there is no asynchronous prefetch operation.
        """
        print("repair strategy ......")
        peak_mem_saving = 0
        while len(self.region_to_region_map) and peak_mem_saving <= 0:
            max_profit = (0,)
            best_ts = None
            undo_host_region = None
            undo_offload_region = None
            for offload_region_id, host_region in self.region_to_region_map.items():
                offload_region = self.region_list[offload_region_id]
                assert host_region.bwd_prefetch_region == offload_region
                assert offload_region.need_offload
                assert not offload_region.is_syn
                ts, profit = self._try_convert_to_syn_upload(host_region,
                                                             offload_region)
                if self._compare_profit(profit, max_profit):
                    undo_host_region = host_region
                    undo_offload_region = offload_region
                    max_profit = profit
                    best_ts = ts
            if best_ts is None:
                raise NotImplementedError('repair error!')
            # permanently downgrade the chosen prefetch to a synchronous upload
            assert not undo_offload_region.is_syn
            undo_offload_region.is_syn = True
            undo_host_region.bwd_prefetch_region = None
            peak_mem_saving = self.best_ts.peak_mem - best_ts.peak_mem
            self._update_state(best_ts)
            self.region_to_region_map.pop(undo_offload_region.r_id)
        return best_ts

    def _eval_one_choice(self):
        """
        Evaluate the profit of a strategy choice.

        Returns:
            AsynTrainingSimulator: the training simulator corresponding to the current strategy.
            tuple: contains memory saving and cost information of the current strategy.
        """
        ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()
        # only growth of the iteration time counts as extra cost
        extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
        profit = self._compute_offload_profit(
            ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
        return ts, profit
class SolverFactory:
    """
    Registry mapping a policy name ('syn' / 'asyn') to its Solver class.
    """
    solvers: Dict[str, Type[Solver]] = {
        'syn': SynGreedySolver,
        'asyn': AsynGreedySolver
    }

    @staticmethod
    def create(solver_name: str) -> Type[Solver]:
        """Return the solver class registered under ``solver_name``."""
        try:
            return SolverFactory.solvers[solver_name]
        except KeyError:
            raise TypeError(f"Unknown parameter offload policy {solver_name}") from None

    @staticmethod
    def get_solver_names():
        """Return the registered policy names as a tuple."""
        return tuple(SolverFactory.solvers)
import bisect
from typing import List, Dict
from collections import OrderedDict
from abc import ABC, abstractmethod
from torch.fx.node import Node
from .region import Region
from .util import *
@dataclass
class ExecutionPeriod:
    """Start/end timestamps of one simulated operation on an execution stream."""
    start_time: float = 0
    end_time: float = 0
class TrainingSimulator(ABC):
    """
    The Training Simulator is used to simulate the training process.
    It records computation, communication, and runtime memory during forward and backward passes.

    Args:
        region_list (List[Region]): represents the linearized DNN computing graph.
        comp_power (float): the NVIDIA GPU FP16 computing power.
        link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
    """

    def __init__(self,
                 region_list: "List[Region]",
                 comp_power: float,
                 link_to_bw: Dict[str, Dict[float, float]]) -> None:
        self.region_list = region_list
        self.region_num = len(region_list)
        # running and peak memory trackers (bytes)
        self.runtime_mem: int = 0
        self.peak_mem: int = 0
        self.total_mem_saving: int = 0
        # per-node simulated runtime memory in forward / backward pass
        self.fwd_node_mem: Dict[Node, float] = {}
        self.bwd_node_mem: Dict[Node, float] = {}
        # Node dependencies in backward pass
        self.bwd_node_deps: Dict[Node, int] = {}
        self.comp_power: float = comp_power
        self.link_to_bandwidth: Dict[str, Dict[float, float]] = link_to_bw

    @abstractmethod
    def execute(self):
        raise NotImplementedError

    @abstractmethod
    def _eval_fwd_mem_per_region(self, region: "Region"):
        raise NotImplementedError

    @abstractmethod
    def _eval_bwd_mem_per_region(self, region: "Region"):
        raise NotImplementedError

    def _get_bandwidth(self, link: str, comm_volumn: float) -> float:
        """
        Get the data transfer bandwidth.

        Args:
            link (str): the data transfer link.
            comm_volumn (float): the amount of data transferred.

        Returns:
            float: the data transfer bandwidth.

        Raises:
            TypeError: the link was not profiled.
        """
        assert len(self.link_to_bandwidth)
        if link not in self.link_to_bandwidth:
            raise TypeError(f"Unknown data transfer link {link}")
        size_list = sorted(self.link_to_bandwidth[link].keys())
        d_idx = bisect.bisect_left(size_list, comm_volumn)
        # Clamp: volumes beyond the largest profiled size reuse the last bucket.
        # (Previously bisect_left could return len(size_list) and raise IndexError
        # for transfers larger than the biggest profiled size.)
        d_idx = min(d_idx, len(size_list) - 1)
        return self.link_to_bandwidth[link][size_list[d_idx]]

    def _get_communication_overhead(self, link: str, comm_volumn: float) -> float:
        """Return the transfer time of ``comm_volumn`` bytes over ``link``."""
        return comm_volumn / self._get_bandwidth(link, comm_volumn)

    def _get_computing_overhead(self, flop: float) -> float:
        """Return the computation time for ``flop`` floating point operations."""
        return flop / self.comp_power
class SynTrainingSimulator(TrainingSimulator):
    """
    Simulator for the synchronous offload strategy: it only tracks memory,
    since blocking uploads introduce no overlap-dependent timing effects.
    """

    def __init__(self,
                 region_list: List[Region],
                 comp_power: float,
                 link_to_bw: Dict[str, Dict[float, float]]) -> None:
        super().__init__(region_list, comp_power, link_to_bw)

    def execute(self):
        """
        Simulate synchronous training process.
        """
        for reg in self.region_list:
            self._eval_fwd_mem_per_region(reg)
        for reg in self.region_list.__reversed__():
            self._eval_bwd_mem_per_region(reg)

    def _eval_fwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the forward execution reaches the current region.
        """
        # upload parameters of the current region
        if requires_upload_p_in_fwd(self.region_list[region.shared_rid]):
            self.runtime_mem += region.param_size
        for node in region.nodes:
            self.runtime_mem += calculate_fwd_tmp(node) + \
                calculate_fwd_out(node)
            self.fwd_node_mem[node] = self.runtime_mem
            self.peak_mem = max(self.runtime_mem, self.peak_mem)
            # saving relative to the no-offload baseline recorded on the node
            self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
        if region.need_offload:
            # parameters leave the GPU right after this region's forward
            self.runtime_mem -= region.param_size

    def _eval_bwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the backward execution reaches the current region.
        """
        # upload parameters of the current region
        if region.need_offload:
            self.runtime_mem += region.param_size
        # add the gradient of the parameter
        if region.r_id < region.shared_rid:
            # gradient accumulation is required for shared parameters
            self.runtime_mem += 2.0 * region.param_size
        else:
            self.runtime_mem += region.param_size
        for node in region.nodes.__reversed__():
            self.runtime_mem -= calculate_fwd_out(node)
            self.runtime_mem += node.meta['bwd_mem_tmp'] + \
                node.meta['bwd_mem_out']
            self.peak_mem = max(self.runtime_mem, self.peak_mem)
            # The memory savings of a node may be negative due to parameter prefetch.
            self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
            self.bwd_node_mem[node] = self.runtime_mem
            self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
                                 calculate_fwd_tmp(node))
            # free bwd_mem_out once every consumer of the producing node has run
            self.bwd_node_deps[node] = len(node.all_input_nodes)
            for user_node in node.users:
                if user_node in self.bwd_node_deps:
                    self.bwd_node_deps[user_node] -= 1
                    if self.bwd_node_deps[user_node] <= 0:
                        self.runtime_mem -= user_node.meta['bwd_mem_out']
            if self.runtime_mem < 0:
                raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
                                 f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
                                 f"runtime memory computed less than 0, which is miscalculated!")
        # release parameter and offload gradient in region
        if region.r_id == region.shared_rid:
            self.runtime_mem -= 2.0 * region.param_size
        elif region.r_id < region.shared_rid:
            self.runtime_mem -= 3.0 * region.param_size
        elif self.region_list[region.shared_rid].need_offload:
            self.runtime_mem -= region.param_size
class AsynTrainingSimulator(TrainingSimulator):
    """
    Simulator for the asynchronous offload strategy: besides memory it also
    tracks per-region execution periods on the computation, h2d and d2h
    streams, so transfer/compute overlap is accounted for in the timings.
    """

    def __init__(self,
                 region_list: List[Region],
                 comp_power: float,
                 link_to_bw: Dict[str, Dict[float, float]]) -> None:
        super().__init__(region_list, comp_power, link_to_bw)
        # end time of the whole simulated iteration
        self.iter_end_time: int = 0
        # the last computation execution period
        self.last_comp: ExecutionPeriod = ExecutionPeriod(
            start_time=0, end_time=0)
        # the last parameter prefetch execution period
        self.last_h2d: ExecutionPeriod = ExecutionPeriod(
            start_time=0, end_time=0)
        # the last gradient offload execution period
        self.last_d2h: ExecutionPeriod = ExecutionPeriod(
            start_time=0, end_time=0)
        # the forward computation execution period of the region
        self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the forward parameter prefetch execution period of the region
        self.fwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the backward computation execution period of the region
        self.bwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the backward parameter prefetch execution period of the region
        self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the gradient offload execution period of the region
        # which is divided into those that are waiting and those that have been released
        self.bwd_reg_to_offl_waiting: OrderedDict[int,
                                                  ExecutionPeriod] = OrderedDict()
        self.bwd_reg_to_offl_freed: OrderedDict[int,
                                                ExecutionPeriod] = OrderedDict()
        # the region buffer, which records regions that are offloaded but not released
        self.reg_buffer_to_free: List[int] = []
        # node dependencies in backward pass
        self.bwd_node_deps: Dict[Node, int] = {}
        # the region execution flow,
        # where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
        # when the execution reaches the i-th region.
        self.fwd_reg_flow = torch.zeros(
            (self.region_num, self.region_num)).bool()
        self.bwd_reg_flow = torch.zeros(
            (self.region_num, self.region_num)).bool()

    def execute(self):
        """
        Simulate asynchronous training process.
        In forward pass, parameter prefetching is advanced by one region.
        In backward pass, parameter prefetching is executed at the specified location,
        and gradient offloading is urgent.
        """
        for reg in self.region_list:
            if reg.param_size and reg.r_id < self.region_num - 1:
                # pick the next region with uploadable parameters as the fwd prefetch target
                for nr in self.region_list[reg.r_id + 1:]:
                    if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
                        reg.fwd_prefetch_region = nr
                        break
            self._eval_fwd_cost_per_region(reg)
            self._eval_fwd_mem_per_region(reg)
        for reg in self.region_list.__reversed__():
            self._eval_bwd_cost_per_region(reg)
            self._eval_bwd_mem_per_region(reg)
        # release remaining grads
        for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items():
            self.bwd_reg_to_offl_freed[reg_id] = offl_exec
            self.runtime_mem -= self.region_list[reg_id].param_size
        self.bwd_reg_to_offl_waiting.clear()
        # the iteration ends when both computation and gradient offload finish
        self.iter_end_time = max(
            self.last_comp.end_time, self.last_d2h.end_time)

    def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
        """
        Insert parameter prefetch execution period of the current region to the end of the h2d stream
        """
        pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time)
        # factor 2.0 presumably models round-trip/packing overhead — TODO confirm
        pref_end_time = pref_start_time + \
            2.0 * self._get_communication_overhead('h2d', region.param_size)
        pref_ep = ExecutionPeriod(
            start_time=pref_start_time, end_time=pref_end_time)
        if is_fwd:
            self.fwd_reg_to_pref[region.r_id] = pref_ep
        else:
            self.bwd_reg_to_pref[region.r_id] = pref_ep
        self.last_h2d = pref_ep

    def _insert_comp_exec(self, region: Region, is_fwd: bool = True):
        """
        Insert computation execution period of the current region to the end of the computing stream
        """
        if is_fwd:
            reg_to_comp = self.fwd_reg_to_comp
            reg_to_pref = self.fwd_reg_to_pref
            flop_key = 'fwd_flop'
        else:
            reg_to_comp = self.bwd_reg_to_comp
            reg_to_pref = self.bwd_reg_to_pref
            flop_key = 'bwd_flop'
        # computation waits for both the previous computation and this region's prefetch
        comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(
            region.r_id, ExecutionPeriod(0, 0)).end_time)
        comp_end_time = comp_start_time + \
            sum([self._get_computing_overhead(node.meta.get(flop_key, 0))
                 for node in region.nodes])
        comp_ep = ExecutionPeriod(
            start_time=comp_start_time, end_time=comp_end_time)
        reg_to_comp[region.r_id] = comp_ep
        self.last_comp = comp_ep

    def _insert_d2h_exec(self, region: Region):
        """
        Insert gradient offload execution period of the current region to the end of the d2h stream
        """
        offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time)
        offl_end_time = offl_start_time + \
            self._get_communication_overhead('d2h', region.param_size)
        offl_ep = ExecutionPeriod(
            start_time=offl_start_time, end_time=offl_end_time)
        self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep
        self.last_d2h = offl_ep

    def _eval_fwd_cost_per_region(self, region: Region):
        """
        Evaluate computation and communication execution period of the region in forward pass.
        """
        # upload parameters of the first region
        if region.r_id == 0:
            self._insert_h2d_exec(region)
        # prefetch parameters of the next region
        fwd_prefetch_region = region.fwd_prefetch_region
        if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
            self._insert_h2d_exec(fwd_prefetch_region)
        # execute computation
        self._insert_comp_exec(region)

    def _eval_fwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the forward execution reaches the current region.
        """
        # upload parameters of the current region
        if region.r_id <= 0:
            self.runtime_mem += region.param_size
            self.fwd_reg_flow[region.r_id, region.r_id] = True
        else:
            # inherit residency from the previous region, minus freed buffers
            self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
            self.fwd_reg_flow[region.r_id,
                              self.reg_buffer_to_free] = False
            self.reg_buffer_to_free.clear()
        # prefetch parameters of the next region
        fwd_prefetch_region = region.fwd_prefetch_region
        if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
            self.runtime_mem += fwd_prefetch_region.param_size
            self.fwd_reg_flow[region.r_id,
                              fwd_prefetch_region.r_id] = True
        for node in region.nodes:
            self.runtime_mem += calculate_fwd_tmp(node) + \
                calculate_fwd_out(node)
            self.peak_mem = max(self.runtime_mem, self.peak_mem)
            self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
            self.fwd_node_mem[node] = self.runtime_mem
        if region.need_offload:
            self.runtime_mem -= region.param_size
            assert len(
                self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}'
            # mark for removal from the residency flow at the next region
            self.reg_buffer_to_free.append(region.r_id)

    def _eval_bwd_cost_per_region(self, region: Region):
        """
        Evaluate computation and communication execution period of the region in backward pass.
        """
        # upload parameters of the current region
        if region.is_syn:
            assert region.need_offload
            self._insert_h2d_exec(region, is_fwd=False)
        # prefetch parameters of the region choiced, which is parallel to computation
        if region.bwd_prefetch_region is not None:
            self._insert_h2d_exec(region.bwd_prefetch_region, is_fwd=False)
        # execute computation
        self._insert_comp_exec(region, is_fwd=False)
        # offload gradient
        if requires_offload_g_in_bwd(region):
            self._insert_d2h_exec(region)
        assert len(self.reg_buffer_to_free) == 0
        # gradients whose offload finished before this computation started can be freed
        for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items():
            if offl_exec.end_time >= self.last_comp.start_time:
                break
            self.reg_buffer_to_free.append(reg_id)
            self.bwd_reg_to_offl_freed[reg_id] = offl_exec
        for reg_id in self.reg_buffer_to_free:
            self.bwd_reg_to_offl_waiting.pop(reg_id)

    def _eval_bwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the backward execution reaches the current region.
        """
        # inherit residency from the following region (or the last fwd state)
        if region.r_id + 1 < self.region_num:
            self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
        else:
            self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
        self.bwd_reg_flow[region.r_id,
                          self.reg_buffer_to_free] = False
        # free gradients in the buffer
        while len(self.reg_buffer_to_free):
            reg_id = self.reg_buffer_to_free.pop(0)
            self.runtime_mem -= self.region_list[reg_id].param_size
        # upload parameters of the current region
        if region.is_syn:
            self.runtime_mem += region.param_size
            self.bwd_reg_flow[region.r_id, region.r_id] = True
        # prefetch parameters of the region choiced
        bwd_prefetch_region = region.bwd_prefetch_region
        if bwd_prefetch_region:
            self.runtime_mem += bwd_prefetch_region.param_size
            self.bwd_reg_flow[region.r_id,
                              bwd_prefetch_region.r_id] = True
        # add the gradient of the parameter
        if region.r_id < region.shared_rid:
            # gradient accumulation is required for shared parameters
            self.runtime_mem += 2.0 * region.param_size
        else:
            self.runtime_mem += region.param_size
        for node in region.nodes.__reversed__():
            self.runtime_mem -= calculate_fwd_out(node)
            self.runtime_mem += node.meta['bwd_mem_tmp'] + \
                node.meta['bwd_mem_out']
            self.peak_mem = max(self.runtime_mem, self.peak_mem)
            # The memory savings of a node may be negative due to parameter prefetch.
            self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
            self.bwd_node_mem[node] = self.runtime_mem
            self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
                                 calculate_fwd_tmp(node))
            # free bwd_mem_out once every consumer of the producing node has run
            self.bwd_node_deps[node] = len(node.all_input_nodes)
            for user_node in node.users:
                if user_node in self.bwd_node_deps:
                    self.bwd_node_deps[user_node] -= 1
                    if self.bwd_node_deps[user_node] <= 0:
                        self.runtime_mem -= user_node.meta['bwd_mem_out']
            if self.runtime_mem < 0:
                raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
                                 f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
                                 f"runtime memory computed less than 0, which is miscalculated!")
        # release parameters of the region
        if requires_release_p_in_bwd(self.region_list[region.shared_rid]):
            self.runtime_mem -= region.param_size
from dataclasses import dataclass
from typing import List
import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp
from .region import Region
@dataclass
class NodeInfo:
    """Per-node runtime memory statistics attached to a ``torch.fx`` Node."""
    # id of the node in the linearized graph
    node_id: int = 0
    # simulated runtime memory (bytes) when forward execution reaches the node
    runtime_fwd_mem: float = 0
    # simulated runtime memory (bytes) when backward execution reaches the node
    runtime_bwd_mem: float = 0
class NvDevicePower:
    """
    NVIDIA GPU computing performance (TFLOPs).

    Per-device peak FP16/FP32 throughput; multiplied by 1e12 by the solver to
    estimate computation time from FLOP counts.
    """
    # GeForce RTX 3080
    RTX3080_FP16 = 70
    RTX3080_FP32 = 34.1
    # GeForce RTX 3090
    RTX3090_FP16 = 71
    RTX3090_FP32 = 35.7
    # Tesla V100
    V100_FP16 = 31.4
    V100_FP32 = 15.7
    # A100
    A100_FP16 = 78
    A100_FP32 = 19.5
class GlobalRuntimeInfo(metaclass=SingletonMeta):
    """
    Process-wide singleton holding the CUDA side streams, prefetch events and
    region list shared by the offload runtime.
    """

    def __init__(self):
        # dedicated CUDA streams so uploads/offloads can overlap computation
        self.h2d_stream = torch.cuda.Stream()
        self.d2h_stream = torch.cuda.Stream()
        # prefetch completion events; presumably keyed by region id — TODO confirm
        self.fwd_prefetch_event_map = {}
        self.bwd_prefetch_event_map = {}
        self.region_list = []
def compute_act_peak_mem(region_list: "List[Region]") -> float:
    """
    Compute the peak activation memory over one forward+backward pass,
    assuming no parameter offloading.

    Args:
        region_list (List[Region]): the linearized DNN computing graph.

    Returns:
        float: the peak activation memory (bytes).
    """
    act_peak_mem = 0
    runtime_mem = 0
    # forward
    for region in region_list:
        for node in region.nodes:
            runtime_mem = runtime_mem + \
                calculate_fwd_tmp(node) + calculate_fwd_out(node)
            act_peak_mem = max(runtime_mem, act_peak_mem)
    # backward (use reversed() instead of calling __reversed__ directly)
    bwd_deps = {}
    for region in reversed(region_list):
        for node in reversed(region.nodes):
            runtime_mem -= calculate_fwd_out(node)
            runtime_mem = runtime_mem + \
                node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']
            act_peak_mem = max(runtime_mem, act_peak_mem)
            runtime_mem = runtime_mem - \
                node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node)
            # free bwd_mem_out once every consumer of the node has executed
            bwd_deps[node] = len(node.all_input_nodes)
            for user_node in node.users:
                if user_node in bwd_deps:
                    bwd_deps[user_node] -= 1
                    if bwd_deps[user_node] <= 0:
                        runtime_mem -= user_node.meta['bwd_mem_out']
    return act_peak_mem
def compute_max_param_mem(region_list: "List[Region]") -> float:
    """
    Return the largest parameter size (bytes) among all regions.

    Returns 0 for an empty region list instead of raising ``ValueError``.
    """
    return max((region.param_size for region in region_list), default=0)
def compute_total_param_mem(region_list: List[Region]) -> float:
    """
    Sum the parameter sizes over all regions, counting shared parameters only
    once: a region contributes only when it is the earliest holder of its
    parameters (``r_id <= shared_rid``).
    """
    total = 0
    for region in region_list:
        if region.r_id <= region.shared_rid:
            total += region.param_size
    return total
def requires_upload_p_in_fwd(shared_reg: "Region"):
    """
    Whether the parameters held by ``shared_reg`` must be uploaded to the GPU
    in the forward pass: true when the region is (or follows) the holder of
    the parameters, or when the earlier holder is offloaded.
    """
    # Simplified from `(r_id >= shared_rid) or (r_id < shared_rid and need_offload)`:
    # the second clause's `r_id < shared_rid` test is redundant because the
    # `or` only reaches it when `r_id < shared_rid` already holds.
    return (shared_reg.r_id >= shared_reg.shared_rid) or shared_reg.need_offload
def requires_release_p_in_bwd(shared_reg: "Region"):
    """
    Whether the parameters held by ``shared_reg`` should be released after its
    backward pass.

    NOTE: intentionally the same predicate as ``requires_upload_p_in_fwd``
    (a parameter uploaded for a pass must be released afterwards); kept as a
    separate function for call-site readability.
    """
    # Simplified: the `r_id < shared_rid` guard on the second clause was redundant.
    return (shared_reg.r_id >= shared_reg.shared_rid) or shared_reg.need_offload
def requires_offload_g_in_bwd(region: Region):
    """
    Whether the gradient produced by ``region`` in the backward pass should be
    offloaded: truthy only for regions that own parameters (non-zero
    ``param_size``) and are the earliest holder of them (``r_id <= shared_rid``).
    """
    return region.param_size and (region.r_id <= region.shared_rid)
from typing import Dict
import torch
from torch.fx import GraphModule
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
# Module-level singleton used by the metainfo helpers below to compute the
# communication sequence and cost of converting between sharding specs.
shape_consistency_manager = ShapeConsistencyManager()
def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
                               target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
    """Build the ShardMetaInfo for converting a tensor between two sharding specs.

    Queries the shape consistency manager for the communication sequence and
    its cost, scales the numel-based costs into bytes using the input tensor's
    element size, and records per-device input/output shapes as meta tensors.
    """
    # comm_action_sequence and total_cost come from the shape consistency manager
    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
        origin_sharding_spec, target_sharding_spec)

    meta_info = ShardMetaInfo()

    # NOTE: shape_consistency_manager.mem_cost counts numel, so each entry is
    # scaled by the element size (bytes) of the node's input tensor.
    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
    input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
    element_length = input_node._meta_data.element_size()
    for phase_cost in (mem_cost.fwd, mem_cost.bwd):
        phase_cost.activation *= element_length
        phase_cost.temp *= element_length
    mem_cost.total.activation *= element_length
    meta_info.memory_cost = mem_cost

    # computation cost, scaled the same way as memory
    meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
                                            total_cost['backward'] * element_length,
                                            total_cost['total'] * element_length)

    # record the per-device tensor shapes on the meta device
    sharded_in_shape = origin_sharding_spec.get_sharded_shape_per_device()
    sharded_out_shape = target_sharding_spec.get_sharded_shape_per_device()
    meta_info.fwd_in = [torch.rand(sharded_in_shape, device='meta')]
    meta_info.fwd_buffer = []
    meta_info.fwd_out = [torch.rand(sharded_out_shape, device='meta')]
    return meta_info
def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> ShardMetaInfo:
    """
    Construct the `ShardMetaInfo` for a shape consistency (runtime_apply) node.
    """
    # runtime_apply args layout: (node, origin_dict, input_dict, node_index, user_node_index)
    node_index, user_node_index = node.args[3], node.args[4]
    src_spec = origin_spec_dict[node_index]
    dst_spec = sharding_spec_dict[node_index][user_node_index]
    return _construct_shard_meta_info(node, src_spec, dst_spec)
def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> ShardMetaInfo:
    """Construct the ShardMetaInfo for a runtime_comm_spec_apply node.

    Args:
        node (Node): the runtime_comm_spec_apply node; its args carry the
            node index (args[2]) and the op data name (args[3]).
        comm_actions_dict (Dict): maps node index -> op data name -> CommAction.

    Returns:
        ShardMetaInfo: memory/compute cost and fwd tensor shapes of the
            communication performed by this node.
    """
    # extract node_index and op_data_name
    node_index, op_data_name = node.args[2], node.args[3]
    comm_action = comm_actions_dict[node_index][op_data_name]
    if isinstance(comm_action.comm_spec, CommSpec):
        # this case is for all_reduce, there will be no memory cost
        meta_info = ShardMetaInfo()
        # BUG FIX: the original passed the MemoryCost *class* (not an instance)
        # as the `total` field; instantiate it like the fwd/bwd fields.
        meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost())
        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
        element_length = output_node._meta_data.element_size()
        total_cost = comm_action.comm_spec.get_comm_cost()
        # scale numel-based costs into bytes
        meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
                                                total_cost['backward'] * element_length,
                                                total_cost['total'] * element_length)
        # all_reduce keeps the sharded shape unchanged
        input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
        meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
        meta_info.fwd_buffer = []
        meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
    else:
        # this case will be handled by shape consistency manager
        origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
            'tgt_spec']
        meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
    return meta_info
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
                       comm_actions_dict: Dict) -> GraphModule:
    """
    The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
    """
    for node in gm.graph.nodes:
        if node.target == runtime_apply:
            meta = _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)
        elif node.target == runtime_comm_spec_apply:
            meta = _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)
        else:
            # not a communication node -- nothing to attach
            continue
        setattr(node, 'best_strategy_info', meta)
    return gm
import torch
# Functions/modules whose outputs are saved for backward and which the
# metainfo pass treats as in-place-like (see `MetaInfoProp._is_inplace`):
# their output tensor shares identity with their input.
OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten]
OUTPUT_SAVED_MOD = [
    torch.nn.ReLU,
    torch.nn.Softmax,
]
# SHAPE_ARGUMENT_OPS contains node with (input, *shape) style args.
# This list could be extended if any other method has the same
# argument style as view and reshape.
SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape]
import uuid
from dataclasses import asdict
from typing import List
import torch
import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo
def _normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
@compatibility(is_backward_compatible=False)
class MetaInfoProp:
    """Propagate sharding metainfo through an FX graph.

    For every node carrying a ``best_strategy_info`` (ShardMetaInfo), copies
    its memory/compute costs and forward tensors into ``node.meta`` (as a
    ``GraphInfo`` dict) and assigns synthetic ``data_ptr`` identities so that
    tensors sharing storage across nodes can be recognized later.
    """

    def __init__(self, module: GraphModule) -> None:
        self.module = module
        # dispatch table: node.op -> handler method
        self.func_dict = {
            'placeholder': self.placeholder_handler,
            'get_attr': self.get_attr_handler,
            'output': self.output_handler,
            'call_function': self.node_handler,
            'call_module': self.node_handler,
            'call_method': self.node_handler,
        }

    def _set_data_ptr(self, x):
        """
        Set uuid to tensor
        """
        # Tensors reporting data_ptr() == 0 (e.g. meta tensors) get a fresh
        # uuid patched in as a unique storage identity.
        if isinstance(x, torch.Tensor):
            if not x.data_ptr():
                data_ptr = uuid.uuid4()
                x.data_ptr = lambda: data_ptr

    def _is_inplace(self, node: Node):
        """
        Check if the node is inplace operation.
        """
        # "inplace" here means the op's output aliases its input for memory
        # accounting purposes (see OUTPUT_SAVED_OPS / OUTPUT_SAVED_MOD).
        if node.op == 'call_module':
            return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
        elif node.op == "call_function":
            return node.target in OUTPUT_SAVED_OPS
        return False

    def run(self) -> GraphModule:
        """
        Run the meta information propagation pass on the module.
        """
        # NOTE(review): annotated to return GraphModule but returns None --
        # callers appear to rely solely on the in-place node.meta updates.
        for node in self.module.graph.nodes:
            node: Node
            self.func_dict[node.op](node)

    @compatibility(is_backward_compatible=False)
    def placeholder_handler(self, node: Node) -> None:
        """
        Handle the placeholder node.
        """
        graph_info = GraphInfo()
        # a placeholder's forward output is its own _meta_data (if any)
        out = _normalize_tuple(getattr(node, '_meta_data', None))
        graph_info.fwd_out = list(out) if out[0] is not None else []
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def get_attr_handler(self, node: Node) -> None:
        """
        Handle the get_attr node.
        """
        # get_attr nodes contribute no activation memory; attach empty info
        graph_info = GraphInfo()
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def output_handler(self, node: Node) -> None:
        """
        Handle the output node.
        """
        # the graph output's inputs are the union of its parents' fwd_out
        graph_info = GraphInfo()
        output_tensors = []
        for par in node._input_nodes:
            if par.meta:
                output_tensors += par.meta["fwd_out"]
        graph_info.fwd_in = output_tensors
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def node_handler(self, node: Node) -> None:
        """
        Handle other kind of nodes
        """
        assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
        graph_info = GraphInfo()
        meta_info = node.best_strategy_info
        meta_info: ShardMetaInfo

        # set data_ptr for input_tensor in ShardMetaInfo class
        input_tensors: List[torch.Tensor] = meta_info.fwd_in
        buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
        output_tensors: List[torch.Tensor] = meta_info.fwd_out

        if self._is_inplace(node):
            # inplace operation will not create new tensor, and it only has one parent node
            # TODO: Verify this observation
            # set data_ptr for input_tensor, buffer_tensor and output_tensor of current node
            parent_node = list(node._input_nodes.keys())[0]
            parent_tensor = parent_node.meta.get("fwd_out")[0]
            parent_tensor: torch.Tensor
            for tensor in input_tensors:
                tensor.data_ptr = parent_tensor.data_ptr
            for tensor in buffer_tensors:
                tensor.data_ptr = parent_tensor.data_ptr
            for tensor in output_tensors:
                tensor.data_ptr = parent_tensor.data_ptr
        else:
            for par in node._input_nodes:
                # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
                # matching is by shape among tensors not yet assigned an identity
                for tensor in par.meta.get("fwd_out", []):
                    tensor: torch.Tensor
                    target_input_tensor = next(
                        (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
                    if target_input_tensor is not None:
                        target_input_tensor.data_ptr = tensor.data_ptr

            # set data_ptr for tensor in input_tensor that is not set
            for tensor in input_tensors:
                if not tensor.data_ptr():
                    self._set_data_ptr(tensor)

            # set data_ptr for buffer_tensor
            for tensor in buffer_tensors:
                self._set_data_ptr(tensor)

            # set data_ptr for output_tensor
            for tensor in output_tensors:
                self._set_data_ptr(tensor)

        # attach them to graph_info
        graph_info.fwd_in = input_tensors
        graph_info.fwd_tmp = buffer_tensors
        graph_info.fwd_out = output_tensors

        # fetch other memory informations
        memory_cost = meta_info.memory_cost
        graph_info.fwd_mem_tmp = memory_cost.fwd.temp
        graph_info.fwd_mem_out = memory_cost.fwd.activation
        graph_info.bwd_mem_tmp = memory_cost.bwd.temp
        graph_info.bwd_mem_out = memory_cost.bwd.activation

        # fetch flop information
        # here we use fwd_time and bwd_time to deal with the case that
        # communication cost is a float
        compute_cost = meta_info.compute_cost
        graph_info.fwd_time = compute_cost.fwd
        graph_info.bwd_time = compute_cost.bwd

        node.meta = {**asdict(graph_info)}
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
# Module-level singleton performing runtime sharding-spec conversions for the
# runtime_apply helpers below.
shape_consistency_manager = ShapeConsistencyManager()
def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int):
    """
    This method will be invoked during runtime to do the shape consistency, which make sure the activations is converted into
    the user node expected form.
    """
    src_spec = origin_dict[node_index]
    dst_spec = input_dict[node_index][user_node_index]
    return shape_consistency_manager.apply_for_autoparallel_runtime(node, src_spec, dst_spec)
def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
                                      user_node_index: int):
    """
    This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list
    is converted into the user node expected form.
    """
    src_specs = origin_dict[node_index]
    dst_specs = input_dict[node_index][user_node_index]
    converted = [
        shape_consistency_manager.apply_for_autoparallel_runtime(node[idx], src_spec, dst_spec)
        for idx, (src_spec, dst_spec) in enumerate(zip(src_specs, dst_specs))
    ]
    # preserve the container type of the input (tuple or list)
    return type(node)(converted)
def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
    """
    This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
    """
    comm_action = comm_actions_dict[node_index][op_data_name]
    spec = comm_action.comm_spec
    if isinstance(spec, CommSpec):
        # plain collective (e.g. all-reduce) described directly by a CommSpec
        return spec.covert_spec_to_action(tensor)
    # dict-style spec: convert between src and tgt sharding specs
    return shape_consistency_manager.apply_for_autoparallel_runtime(tensor, spec['src_spec'], spec['tgt_spec'])
def _preprocess_graph(nodes: List[Node]):
"""
This method is used to extract all the placeholders with sharding information,
and mapping the nodes into the index of the origin graph.
"""
# mapping the node into the origin graph index
node_to_index_dict = {}
index = 0
for node in nodes:
if node.target == 'sharding_spec_convert_dict':
input_dict_node = node
continue
if node.target == 'origin_node_sharding_spec_dict':
origin_dict_node = node
continue
if node.target == 'comm_actions_dict':
comm_actions_dict_node = node
continue
if not hasattr(node, 'best_strategy'):
continue
node_to_index_dict[node] = index
index += 1
return input_dict_node, origin_dict_node, comm_actions_dict_node, node_to_index_dict
def _shape_consistency_apply(gm: torch.fx.GraphModule):
    """
    This pass is used to add the shape consistency node to the origin graph.

    For every strategy-annotated node whose output sharding spec differs from
    the spec each user expects, a runtime conversion node (runtime_apply or
    runtime_apply_for_iterable_object) is inserted before the user and wired
    into the user's args/kwargs in place of the origin node.
    """
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)
    input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
    for node in nodes:
        if not hasattr(node, 'best_strategy') or node.op == 'output':
            continue
        for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
            if isinstance(node.sharding_spec, (list, tuple)):
                # iterable output: compare each element's spec pairwise
                assert isinstance(
                    node.target_sharding_specs,
                    (list,
                     tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
                total_difference = 0
                for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
                                                               node.target_sharding_specs[user_node_index]):
                    total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
                if total_difference == 0:
                    # already in the expected form for this user
                    continue
                with mod_graph.inserting_before(user_node):
                    shape_consistency_node = mod_graph.create_node('call_function',
                                                                   runtime_apply_for_iterable_object,
                                                                   args=(node, origin_dict_node, input_dict_node,
                                                                         node_to_index_dict[node], user_node_index))
            else:
                assert isinstance(node.sharding_spec,
                                  ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
                if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
                    # already in the expected form for this user
                    continue
                with mod_graph.inserting_before(user_node):
                    shape_consistency_node = mod_graph.create_node('call_function',
                                                                   runtime_apply,
                                                                   args=(node, origin_dict_node, input_dict_node,
                                                                         node_to_index_dict[node], user_node_index))
            # propagate the user's activation-checkpoint annotation to the new node
            if hasattr(user_node.meta['info'], 'activation_checkpoint'):
                MetaInfo(shape_consistency_node,
                         mod_dir=user_node.meta['info'].mod_dir,
                         activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))

            new_args = list(user_node.args)
            new_kwargs = dict(user_node.kwargs)
            # the origin node may be a positional argument or key word argument of user node
            if node in new_args:
                # substitute the origin node with shape_consistency_node
                origin_index_args = new_args.index(node)
                new_args[origin_index_args] = shape_consistency_node
                user_node.args = tuple(new_args)
            elif str(node) in new_kwargs:
                # substitute the origin node with shape_consistency_node
                new_kwargs[str(node)] = shape_consistency_node
                user_node.kwargs = new_kwargs
    return gm
def _comm_spec_apply(gm: torch.fx.GraphModule):
    """
    This pass is used to add the comm spec apply node to the origin graph.

    For every communication action attached to a strategy-annotated node,
    inserts a runtime_comm_spec_apply node either BEFORE the node (rewriting
    the node's own args/kwargs) or AFTER it (rewiring the node's users).
    HOOK-type actions are handled elsewhere via gradient hooks and skipped.
    """
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)
    _, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
    for node in nodes:
        if not hasattr(node, 'best_strategy') or node.op == 'output':
            continue
        comm_actions = node.best_strategy.communication_actions
        for op_data, comm_action in comm_actions.items():
            if comm_action.comm_type == CommType.HOOK:
                # applied through parameter gradient hooks, not graph nodes
                continue
            if comm_action.comm_type == CommType.BEFORE:
                # pick the object the communication acts on: the node itself,
                # a kwarg, or a positional arg
                if op_data.type == OperationDataType.OUTPUT:
                    comm_object = node
                elif comm_action.key_for_kwarg is not None:
                    comm_object = node.kwargs[comm_action.key_for_kwarg]
                else:
                    comm_object = node.args[comm_action.arg_index]
                with mod_graph.inserting_before(node):
                    comm_spec_apply_node = mod_graph.create_node('call_function',
                                                                 runtime_comm_spec_apply,
                                                                 args=(comm_object, comm_actions_dict_node,
                                                                       node_to_index_dict[node], op_data.name))
                # the origin node may be a positional argument or key word argument of user node
                if comm_action.key_for_kwarg is not None:
                    # substitute the origin node with comm_spec_apply_node
                    new_kwargs = dict(node.kwargs)
                    new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node
                    node.kwargs = new_kwargs
                else:
                    # substitute the origin node with comm_spec_apply_node
                    new_args = list(node.args)
                    new_args[comm_action.arg_index] = comm_spec_apply_node
                    node.args = tuple(new_args)
            elif comm_action.comm_type == CommType.AFTER:
                with mod_graph.inserting_after(node):
                    comm_spec_apply_node = mod_graph.create_node('call_function',
                                                                 runtime_comm_spec_apply,
                                                                 args=(node, comm_actions_dict_node,
                                                                       node_to_index_dict[node], op_data.name))
                # redirect every user (except the new node) to consume the
                # communicated result instead of the raw output
                user_list = list(node.users.keys())
                for user in user_list:
                    if user == comm_spec_apply_node:
                        continue
                    new_args = list(user.args)
                    new_kwargs = dict(user.kwargs)
                    # the origin node may be a positional argument or key word argument of user node
                    if node in new_args:
                        # substitute the origin node with comm_spec_apply_node
                        new_args[new_args.index(node)] = comm_spec_apply_node
                        user.args = tuple(new_args)
                    elif str(node) in new_kwargs:
                        # substitute the origin node with comm_spec_apply_node
                        new_kwargs[str(node)] = comm_spec_apply_node
                        user.kwargs = new_kwargs
            # propagate activation-checkpoint annotation to the inserted node
            if hasattr(node.meta['info'], 'activation_checkpoint'):
                MetaInfo(comm_spec_apply_node,
                         mod_dir=node.meta['info'].mod_dir,
                         activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
    return gm
def _act_annotataion_pass(gm: torch.fx.GraphModule):
"""
This pass is used to add the act annotation to the new inserted nodes.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
for node in nodes:
if not hasattr(node.meta, 'activation_checkpoint'):
from .runtime_preparation_pass import size_processing
user_act_annotation = -1
input_act_annotation = -1
for user_node in node.users.keys():
if 'activation_checkpoint' in user_node.meta:
user_act_annotation = user_node.meta['activation_checkpoint']
break
for input_node in node._input_nodes.keys():
if 'activation_checkpoint' in input_node.meta:
input_act_annotation = input_node.meta['activation_checkpoint']
break
if user_act_annotation == input_act_annotation and user_act_annotation != -1:
node.meta['activation_checkpoint'] = user_act_annotation
return gm
def runtime_apply_pass(gm: torch.fx.GraphModule):
    """
    The method manages all the passes acting on the distributed training runtime.
    """
    # run the shape-consistency insertion first, then the comm-spec insertion
    for graph_pass in (_shape_consistency_apply, _comm_spec_apply):
        gm = graph_pass(gm)
    return gm
import operator
from copy import deepcopy
from typing import Dict, List, Union
import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationDataType,
ShardingStrategy,
)
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import SHAPE_ARGUMENT_OPS
# Module-level singleton used by the runtime-preparation passes below to
# redistribute parameters between sharding specs.
shape_consistency_manager = ShapeConsistencyManager()
def size_processing(size: Union[int, torch.Size],
                    dim_partition_dict: Dict[int, List[int]],
                    device_mesh_info: Dict[int, int],
                    target_dim: int = None,
                    node_name: str = None):
    """
    This method will be invoked during runtime to convert size node value depending on distributed information.

    Dimensions listed in ``dim_partition_dict`` are sharded over the given
    device-mesh dims, so their local sizes are multiplied back up to the
    global size. ``target_dim`` distinguishes ``tensor.size(dim)`` (int) from
    ``tensor.size()`` (torch.Size); ``node_name`` is informational only.
    """

    def _shard_factor(mesh_dims):
        # product of mesh extents along which the dimension is sharded
        factor = 1
        for mesh_dim in mesh_dims:
            factor *= device_mesh_info[mesh_dim]
        return factor

    if target_dim is not None:
        # tensor.size(dim) form: a single int
        assert isinstance(size, int)
        if target_dim in dim_partition_dict:
            size = size * _shard_factor(dim_partition_dict[target_dim])
    else:
        # tensor.size() form: rescale every sharded dimension
        scaled = list(size)
        for dim, dim_size in enumerate(scaled):
            if dim in dim_partition_dict:
                scaled[dim] = dim_size * _shard_factor(dim_partition_dict[dim])
        size = torch.Size(scaled)
    return size
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
                               strategies_constructor: StrategiesConstructor):
    """
    This method is used to stick the solution strategy to the nodes and add the information
    required in runtime into graph as placeholder nodes.

    Returns:
        (gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict)

    NOTE(review): the function name carries a historical misspelling
    ("annotatation"); kept for backward compatibility with callers.
    """
    mod_graph = gm.graph
    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
    no_strategy_nodes = strategies_constructor.no_strategy_nodes

    # the dict to get origin sharding spec of node
    origin_node_sharding_spec_dict = {}
    for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
        strategies_vector = node.strategies_vector
        # stick the solution strategy to the corresponding node
        setattr(node, 'best_strategy', strategies_vector[strategy_index])
        setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
        origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
            str(node))
        # attach the corresponding metainfo if node has the attribute `strategies_info`
        if hasattr(node, 'strategies_info'):
            setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])

    # the dict to get input sharding specs of user node
    sharding_spec_convert_dict = {}
    # the dict to record comm actions of nodes
    comm_actions_dict = {}
    for index, node in enumerate(nodes):
        target_sharding_specs = []
        for user_node in node.strategies_vector.successor_nodes:
            # users without a strategy keep the producer's own spec
            if user_node in no_strategy_nodes:
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(str(node.name))
            else:
                target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
            target_sharding_specs.append(target_sharding_spec)
        sharding_spec_convert_dict[index] = target_sharding_specs
        setattr(node, 'target_sharding_specs', target_sharding_specs)

        # the get_attr node strategy is kind of pending strategy, which means we will change it
        # to the same strategy of the user node.
        if node.op == 'get_attr':
            assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
            target_node = node.strategies_vector.successor_nodes[0]
            node_name = str(node)
            # skip over a reshape-style consumer to its own consumer
            if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
                node_name = str(target_node)
                target_node = target_node.strategies_vector.successor_nodes[0]
            user_strategy = target_node.best_strategy
            op_data_in_user = user_strategy.get_op_data_by_name(node_name)
            origin_pending_strategy = node.best_strategy
            origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))

            # move the user's comm action for this weight onto the get_attr node
            new_communication_actions = {}
            if op_data_in_user in user_strategy.communication_actions:
                new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
                new_communication_action.arg_index = 0
                new_communication_actions[origin_op_data] = new_communication_action
            node.best_strategy.communication_actions = new_communication_actions

        # index comm actions by op-data name for runtime lookup
        comm_action_dict = {}
        for op_data, comm_action in node.best_strategy.communication_actions.items():
            comm_action_dict[op_data.name] = comm_action
        comm_actions_dict[index] = comm_action_dict

    # add above dicts into graph
    for node in nodes:
        if node.op != 'placeholder':
            # insert the three runtime-dict placeholders once, before the
            # first non-placeholder node
            with mod_graph.inserting_before(node):
                input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
                origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
                comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
            break
    return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    """
    In the auto parallel system, tensors may get shard on different devices, so the size of tensors
    need to be converted to the size of original tensor and managed by the users, such as torch.view,
    torch.reshape, etc. These nodes have enough information like input sharding_spec and
    output sharding_spec to decide how to convert the size value.

    Two rewrites are performed: every `tensor.size(...)` call gets a trailing
    `size_processing` node that rescales the sharded size to the global size,
    and `operator.getitem` nodes with slice arguments have any size-node
    references inside the slice replaced by the corresponding processed node.
    """
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)
    # original size node -> inserted size_processing node
    node_pairs = {}

    # DeviceMesh information instructs the scaling of the size value
    device_mesh_info = {}
    for dim, dim_size in enumerate(device_mesh.mesh_shape):
        device_mesh_info[dim] = dim_size

    def _extract_target_dim(node):
        '''
        A helper function to extract the target dimension from size node.
        There are two usages of torch.Tensor.size:
        1. tensor.size()
        2. tensor.size(dim)
        If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
        Otherwise, the output will be in type of torch.Size and this function will return None.
        '''
        target_dim = None
        if len(node.args) > 1:
            target_dim = node.args[1]
            # normalize a negative dim against the input tensor's rank
            if target_dim < 0:
                target_dim += node.args[0]._meta_data.dim()
        return target_dim

    def _post_processing(node, size_processing_node):
        '''
        This function is used to process the dependency between the size node and its users after
        inserting the size_process_node.
        '''
        # store original node and processing node pair in node_pairs dictionary
        # It will be used to replace the original node with processing node in slice object
        node_pairs[node] = size_processing_node
        size_processing_node._meta_data = node._meta_data

        # carry over the activation-checkpoint annotation, if present
        if hasattr(node.meta['info'], 'activation_checkpoint'):
            MetaInfo(size_processing_node,
                     mod_dir=node.meta['info'].mod_dir,
                     activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))

        # redirect every user (except the new node itself) to the processed size
        user_list = list(node.users.keys())
        for user in user_list:
            if user == size_processing_node:
                continue
            new_args = list(user.args)
            new_kwargs = dict(user.kwargs)
            # the origin node may be a positional argument or key word argument of user node
            if node in new_args:
                # substitute the origin node with size_processing_node
                new_args[new_args.index(node)] = size_processing_node
                user.args = tuple(new_args)
            elif str(node) in new_kwargs:
                # substitute the origin node with size_processing_node
                new_kwargs[str(node)] = size_processing_node
                user.kwargs = new_kwargs

    def _update_slice_object_args(slice_object):
        '''
        This function is used to update the slice object argument list.
        If the slice object contains the Node argument, then the size node will be replaced with
        the corresponding size_processing node recorded in node_pairs.
        '''
        if isinstance(slice_object, slice):
            start = slice_object.start
            stop = slice_object.stop
            step = slice_object.step
            if start in node_pairs:
                start = node_pairs[start]
            if stop in node_pairs:
                stop = node_pairs[stop]
            if step in node_pairs:
                step = node_pairs[step]
            return slice(start, stop, step)
        elif isinstance(slice_object, int):
            if slice_object in node_pairs:
                return node_pairs[slice_object]
            else:
                return slice_object
        else:
            raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")

    for node in nodes:
        if node.op == 'call_method' and node.target == 'size':
            # extract useful information from size node
            # dim_partition_dict will instruct the size value on which
            # dimension should be enlarged.
            sharding_spec = node.args[0].sharding_spec
            dim_partition_dict = sharding_spec.dim_partition_dict
            target_dim = _extract_target_dim(node)

            # insert size_processing node
            with mod_graph.inserting_after(node):
                size_processing_node = mod_graph.create_node('call_function',
                                                             size_processing,
                                                             args=(node, dim_partition_dict, device_mesh_info,
                                                                   target_dim, node.name))
            _post_processing(node, size_processing_node)

        if node.op == 'call_function' and node.target == operator.getitem:
            getitem_index = node.args[1]
            # slice object is quite special in torch.fx graph,
            # On one side, we treat slice object same as type of int,
            # so we do not create a node for slice object. On the other side,
            # slice object could take fx.Node as its argument. And the user
            # relationship cannot be tracked in fx graph.
            # Therefore, I record the node_pairs in this pass, and use the it
            # to replace the original node argument inside the slice object if
            # it has been processed in above pass.

            # There are three main usages of operator.getitem:
            #   getitem(input, int)
            #   getitem(input, slice)
            #   getitem(input, Tuple[slice])
            # In this pass, we need process the last two cases because
            # node arguments may potentially appear in these cases.
            if isinstance(getitem_index, slice):
                new_slice_item = _update_slice_object_args(getitem_index)
                new_args = (node.args[0], new_slice_item)
                node.args = new_args
            elif isinstance(getitem_index, (tuple, list)):
                if not isinstance(getitem_index[0], slice):
                    continue
                new_slice_items = []
                for slice_item in getitem_index:
                    if slice_item is None:
                        new_slice_items.append(None)
                        continue
                    new_slice_item = _update_slice_object_args(slice_item)
                    new_slice_items.append(new_slice_item)
                new_args = (node.args[0], tuple(new_slice_items))
                node.args = new_args
    return gm
def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    """
    This pass will process node args to adapt the distributed tensor layout.

    For shape-consuming ops (view/reshape, see SHAPE_ARGUMENT_OPS), the global
    shape arguments are divided by the shard factor of each partitioned
    dimension so the op operates on the local (sharded) tensor correctly.
    """
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)

    def _extract_info_from_sharding_spec(sharding_spec):
        '''
        This function is used to extract the dim_partition_dict and device_mesh from
        sharding spec instance or a list of sharding spec.
        '''
        if isinstance(sharding_spec, ShardingSpec):
            dim_partition_dict = sharding_spec.dim_partition_dict
            device_mesh = sharding_spec.device_mesh
            return dim_partition_dict, device_mesh
        if sharding_spec is None:
            return None, None
        assert isinstance(sharding_spec,
                          (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
        # iterable spec: recurse element-wise
        device_mesh = sharding_spec[0].device_mesh
        dim_partition_dict = []
        for element in sharding_spec:
            dim_partition_dict.append(_extract_info_from_sharding_spec(element))
        return dim_partition_dict, sharding_spec

    def _process_node_arguments(node):
        # Flatten the node's args to the (input, *shape) style.
        new_args = []
        for arg in node.args:
            # There are two args style:
            #     1. (input, *shape)
            #     2. (input, shape)
            # We will extract the elements from shape and add them into the new_args
            # Finally, the args style of new_args will be unified to (input, *shape)
            if isinstance(arg, Node):
                if isinstance(arg._meta_data, (tuple, list)):
                    new_args.extend(arg._meta_data)
                elif isinstance(arg._meta_data, int):
                    new_args.append(arg._meta_data)
                else:
                    new_args.append(arg)
            else:
                assert isinstance(arg,
                                  (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
                if isinstance(arg, (tuple, list)):
                    new_args.extend(arg)
                else:
                    new_args.append(arg)
        return new_args

    def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
        # Divide each sharded dimension's global size by its shard factor.
        new_args = _process_node_arguments(node)
        if node.op == 'call_method':
            # first element is the method receiver, not a shape value
            args_to_process = list(new_args[1:])
        else:
            args_to_process = list(new_args)
        for dim, shard_dims in dim_partition_dict.items():
            total_shard_size = 1
            for shard_dim in shard_dims:
                total_shard_size *= device_mesh.shape[shard_dim]
            # we will skip the dim with -1 value
            if args_to_process[dim] == -1:
                continue
            else:
                # TODO: add assertion here to make sure the dim size is divisible by total_shard_size
                args_to_process[dim] //= total_shard_size
        args_to_process = tuple(args_to_process)

        if node.op == 'call_method':
            new_args = (new_args[0],) + args_to_process
        else:
            new_args = args_to_process
        node.args = new_args

    def _filter_node_with_shape_args(node):
        # Only ops taking (input, *shape) style args are rewritten.
        if node.op == 'call_method':
            target = getattr(node.args[0]._meta_data.__class__, node.target)
        elif node.op == 'call_function':
            target = node.target
        else:
            target = None
        if target in SHAPE_ARGUMENT_OPS:
            return True
        return False

    for node in nodes:
        # skip the placeholder node added in _solution_annotation pass
        if not hasattr(node, 'sharding_spec'):
            continue
        output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
        if _filter_node_with_shape_args(node):
            _scale_args_adapt_sharding_spec(output_dim_partition_dict, device_mesh, node)
    return gm
def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
    """
    Apply the sharding action to the module parameters and buffers following the
    instructions of solver solution.

    Args:
        gm: the traced graph module whose submodule parameters/buffers will be
            sharded in place.
        device_mesh: the device mesh the sharding specs refer to.
        overlap: if True, gradient all-reduce runs asynchronously on a side CUDA
            stream so it can overlap with backward computation.

    Returns:
        The same ``gm`` with parameters/buffers replaced by their sharded copies
        and gradient-communication hooks registered.
    """
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)
    # This stream is created for overlaping the communication and computation.
    reduction_stream = torch.cuda.Stream()

    def _add_hook_for_grad_communication(node, param, name=None):
        # Register a backward hook on `param` for every HOOK-type communication
        # action recorded in the node's best strategy.
        comm_actions = node.best_strategy.communication_actions

        def _filter_param_to_hook(node, op_data, comm_action, name):
            # A hook is needed for a call_module parameter matching `name`, or
            # for a get_attr node whose meta data is itself a Parameter.
            if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
                return True
            if node.op == 'get_attr' and isinstance(
                    node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
                return True
            return False

        for operation_data, comm_action in comm_actions.items():
            comm_spec_to_use = comm_action.comm_spec
            # register hook to the parameters
            if _filter_param_to_hook(node, operation_data, comm_action, name=name):

                # `wrapper` exists to capture comm_spec/stream/overlap by value,
                # avoiding the late-binding-closure pitfall inside the loop.
                def wrapper(param, comm_spec, stream, overlap):

                    def hook_fn(grad):
                        if overlap:
                            # async all-reduce on the side stream so it can
                            # overlap with backward computation
                            with torch.cuda.stream(stream):
                                _all_reduce(grad, comm_spec, async_op=True)
                        else:
                            _all_reduce(grad, comm_spec, async_op=False)

                    param.register_hook(hook_fn)

                wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)

    def _shard_param(param, target_sharding_spec):
        # apply the sharding spec of parameters; a parameter whose target spec
        # has no partitioned dims stays fully replicated and is returned as-is
        if target_sharding_spec.dim_partition_dict != {}:
            origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
            setattr(param, 'sharding_spec', origin_sharding_spec)
            # TODO: build a ColoParameter class to manager the distributed parameters
            # we could use .data here, because all the operations just happen before the real training
            # loop, so we don't need to track these operations in the autograd graph.
            param = torch.nn.Parameter(
                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
                                                                         target_sharding_spec).detach().clone())
        return param

    for node in nodes:
        if node.op == 'call_module':
            target_module = node.graph.owning_module.get_submodule(node.target)
            # TODO: we need to do more actions to take care of the shared parameters.
            # the 'processed' flag guards against sharding the same submodule twice
            if hasattr(target_module, 'processed') and target_module.processed:
                continue
            setattr(target_module, 'processed', True)
            for name, param in target_module.named_parameters():
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                param = _shard_param(param, target_sharding_spec)
                setattr(target_module, name, param)
                _add_hook_for_grad_communication(node, param, name)

            sharded_buffer_dict = {}
            # apply the sharding spec of buffers
            for name, buffer in target_module.named_buffers():
                origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
                setattr(buffer, 'sharding_spec', origin_sharding_spec)
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
                sharded_buffer_dict[name] = buffer_sharded

            # write the sharded buffers back in a second pass so that
            # named_buffers() is not mutated while being iterated
            for name, buffer_sharded in sharded_buffer_dict.items():
                setattr(target_module, name, buffer_sharded.detach().clone())

        if node.op == 'get_attr':
            # resolve the owning module and attribute for a (possibly dotted)
            # get_attr target such as "layer.weight"
            root = node.graph.owning_module
            atoms = node.target.split(".")
            attr_len = len(atoms)
            if attr_len == 1:
                target_module = root
                target = getattr(root, atoms[0])
            else:
                target_module = root
                for atom in atoms[:-1]:
                    target_module = getattr(target_module, atom)
                target = getattr(target_module, atoms[-1])
            target_sharding_spec = node.sharding_spec
            target = _shard_param(target, target_sharding_spec)
            assert hasattr(target_module, atoms[-1])
            setattr(target_module, atoms[-1], target)
            _add_hook_for_grad_communication(node, target)
    return gm
def implicit_comm_action_apply(gm: torch.fx.GraphModule):
    """
    replace the origin kernel into kernel with implicit communication inside.

    NOTE(review): this pass is not implemented yet (see the commented-out call
    in ``runtime_preparation_pass``); it currently returns None.
    """
    pass
def runtime_preparation_pass(gm: torch.fx.GraphModule,
                             solution: List[int],
                             device_mesh: DeviceMesh,
                             strategies_constructor: StrategiesConstructor,
                             overlap=False):
    """Run every graph-rewriting pass required before runtime execution:
    annotate the solver solution, convert size values and node arguments,
    then shard the module parameters.

    Returns:
        A tuple ``(gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict,
        comm_actions_dict)``.
    """
    annotated = solution_annotatation_pass(gm, solution, strategies_constructor)
    gm, spec_convert_dict, origin_spec_dict, comm_actions_dict = annotated
    gm = size_value_converting_pass(gm, device_mesh)
    gm = node_args_converting_pass(gm, device_mesh)
    # TODO: enable once implicit_comm_action_apply_pass is implemented.
    # gm = implicit_comm_action_apply(gm)
    gm = module_params_sharding_pass(gm, device_mesh, overlap=overlap)
    return gm, spec_convert_dict, origin_spec_dict, comm_actions_dict
import operator
import torch
__all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
]
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.abs,
torch.cos,
torch.exp,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
# softmax should not be here
torch.nn.functional.softmax
]
ELEMENTWISE_METHOD_OP = [
torch.Tensor.to,
torch.Tensor.type,
# TODO: contiguous maybe need some extra processes.
torch.Tensor.contiguous
]
RESHAPE_FUNC_OP = [
torch.flatten,
torch.reshape,
torch.transpose,
torch.split,
torch.permute,
operator.getitem,
]
RESHAPE_METHOD_OP = [
torch.Tensor.view,
torch.Tensor.unsqueeze,
torch.Tensor.split,
torch.Tensor.permute,
torch.Tensor.transpose,
]
BCAST_FUNC_OP = [
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d
]
CONV_FUNC_OP = [
torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
]
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
NON_PARAM_FUNC_OP = [
torch.flatten,
torch.reshape,
torch.abs,
torch.cos,
torch.exp,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
torch.flatten,
torch.where,
operator.pow,
torch.pow,
torch.tanh,
torch.add,
torch.sub,
torch.mul,
torch.div,
torch.floor_divide,
torch.true_divide,
operator.add,
operator.sub,
operator.mul,
operator.floordiv,
operator.truediv,
# softmax should not be here
torch.nn.functional.softmax
]
INFINITY_COST = 1e13
from typing import Dict, List, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.fx import GraphModule
from torch.fx.graph import Graph
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
class ModuleWrapper(nn.Module):
    '''
    Wrap a sharded graph module so that the runtime dictionaries
    (target sharding specs, original sharding specs and communication actions)
    are forwarded automatically on every call.
    '''

    def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
                 origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
        '''
        Args:
            module: the original module
            sharding_spec_dict: records the target sharding specs of each tensor required in user node.
            origin_spec_dict: records the original sharding spec of each tensor.
            comm_actions_dict: records the communication actions of each tensor.
        '''
        super().__init__()
        self.module = module
        self.sharding_spec_dict = sharding_spec_dict
        self.origin_spec_dict = origin_spec_dict
        self.comm_actions_dict = comm_actions_dict

    def forward(self, *args, **kwargs):
        # inject the runtime dictionaries alongside the caller's own kwargs
        runtime_kwargs = dict(sharding_spec_convert_dict=self.sharding_spec_dict,
                              origin_node_sharding_spec_dict=self.origin_spec_dict,
                              comm_actions_dict=self.comm_actions_dict)
        return self.module(*args, **runtime_kwargs, **kwargs)
def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
    '''
    This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func.

    NOTE(review): currently a stub returning None; ``autoparallelize`` calls this
    when ``meta_args`` is None, which will fail downstream until implemented.
    '''
    # TODO: implement this function
    pass
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
    '''
    This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
    from the alpha_beta_dict. These two values will be used to estimate the communication cost.

    NOTE(review): currently a stub returning None; callers that unpack its
    result (see ``initialize_device_mesh``) will fail until implemented.
    '''
    # TODO: implement this function
    pass
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
                               shard_option: str):
    '''
    Build the strategies_constructor for the given graph.
    After this call, each node in the graph owns a strategies_vector which
    was constructed by the related node handler.
    '''
    solver_preference_map = {
        'standard': SolverPerference.STANDARD,
        'tp': SolverPerference.TP,
        'dp': SolverPerference.DP,
    }
    if solver_preference not in solver_preference_map:
        raise ValueError(f'Invalid solver_preference: {solver_preference}')
    solver_preference = solver_preference_map[solver_preference]

    dataloader_option_map = {
        'replicated': DataloaderOption.REPLICATED,
        'distributed': DataloaderOption.DISTRIBUTED,
    }
    if dataloader_option not in dataloader_option_map:
        raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
    dataloader_option = dataloader_option_map[dataloader_option]

    shard_option_map = {
        'standard': ShardOption.STANDARD,
        'shard': ShardOption.SHARD,
        'shard_last_axis': ShardOption.SHARD_LAST_AXIS,
        'full_shard': ShardOption.FULL_SHARD,
    }
    if shard_option not in shard_option_map:
        raise ValueError(f'Invalid shard_option: {shard_option}')
    shard_option = shard_option_map[shard_option]

    # NOTE: 'solver_perference' (sic) is the keyword actually declared by SolverOptions.
    solver_options = SolverOptions(solver_perference=solver_preference,
                                   dataloader_option=dataloader_option,
                                   shard_option=shard_option)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    return strategies_constructor
def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
    '''
    Solve the best sharding solution for the given graph.
    The returned solution is a list of integers; each one is the index of the
    best strategy for the corresponding node.
    '''
    # temporarily we use all nodes as liveness list, we count the backward memory cost together with
    # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
    # graph_analyser = GraphAnalyser(gm)
    # liveness_list = graph_analyser.liveness_analysis()
    cost_graph = CostGraph(strategy_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    solver = Solver(gm.graph, strategy_constructor, cost_graph, memory_budget=memory_budget)
    serialized_result = solver.call_solver_serialized_args()
    # the first element of the serialized result holds the strategy indices
    return list(serialized_result[0])
def transform_to_sharded_model(gm: ColoGraphModule,
                               meta_args: Dict,
                               solution: List[int],
                               device_mesh: DeviceMesh,
                               strategies_constructor: StrategiesConstructor,
                               overlap: bool = False):
    '''
    Transform the original graph into a sharded graph.
    Model parameters are sharded according to ``solution`` and grad hooks are
    installed by ``runtime_preparation_pass``; communication nodes are inserted
    by ``runtime_apply_pass``.
    '''
    prepared = runtime_preparation_pass(gm, solution, device_mesh, strategies_constructor, overlap=overlap)
    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = prepared
    gm = runtime_apply_pass(gm)
    # re-propagate shapes now that communication nodes exist, then recompile
    sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
    shape_prop_pass(gm, *meta_args.values(), *sharding_spec_dicts)
    gm.recompile()
    return gm, sharding_spec_dicts
def initialize_device_mesh(world_size: int = -1,
                           physical_devices: List[int] = None,
                           alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
                           logical_mesh_shape: Tuple[int] = None,
                           logical_mesh_id: torch.Tensor = None):
    '''
    This method is used to initialize the device mesh.

    Args:
        world_size: the size of device mesh. If the world_size is -1,
            the world size will be taken from torch.distributed.
        physical_devices: the physical devices used to initialize the device mesh.
        alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
            for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be
            generated by profile_alpha_beta function.
        logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
            mesh shape.
        logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.

    Returns:
        A ``DeviceMesh`` built from the physical mesh, logical mesh id and the
        extracted alpha/beta communication-cost parameters.
    '''
    # if world_size is not set, use the world size from torch.distributed
    if world_size == -1:
        world_size = dist.get_world_size()
    if physical_devices is None:
        physical_devices = [i for i in range(world_size)]
    physical_mesh = torch.tensor(physical_devices)

    if alpha_beta_dict is None:
        # if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
        ab_profiler = AlphaBetaProfiler(physical_devices)
        alpha_beta_dict = ab_profiler.alpha_beta_dict
    else:
        ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict)

    if logical_mesh_shape is None and logical_mesh_id is None:
        # search for the best logical mesh shape
        logical_mesh_id = ab_profiler.search_best_logical_mesh()
        logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
        logical_mesh_shape = logical_mesh_id.shape
        # extract alpha and beta values for the chosen logical mesh shape
        mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
    elif logical_mesh_shape is not None and logical_mesh_id is None:
        logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
        # extract alpha and beta values for the chosen logical mesh shape
        # NOTE(review): extract_alpha_beta_for_device_mesh is still a stub whose
        # signature documents a logical_mesh_shape second argument — confirm the
        # intended argument once it is implemented.
        mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
    else:
        # logical_mesh_id was supplied directly. Previously this case fell
        # through both branches above and later crashed with an unbound
        # ``mesh_alpha``/``mesh_beta`` NameError; handle it explicitly.
        logical_mesh_shape = logical_mesh_id.shape
        mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)

    device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
                             logical_mesh_id=logical_mesh_id,
                             mesh_alpha=mesh_alpha,
                             mesh_beta=mesh_beta,
                             init_process_group=True)
    return device_mesh
def initialize_model(model: nn.Module,
                     meta_args: Dict[str, torch.Tensor],
                     device_mesh: DeviceMesh,
                     memory_budget: float = -1.0,
                     overlap: bool = False,
                     solver_preference: str = 'standard',
                     dataloader_option: str = 'replicated',
                     shard_option: str = 'standard',
                     save_solver_solution: bool = False,
                     load_solver_solution: bool = False,
                     solution_path: str = None,
                     return_solution: bool = False):
    '''
    Initialize a sharded model that can be used like a normal PyTorch model.

    Args:
        model: the model to be sharded.
        meta_args: specifies the input shapes of the model.
        device_mesh: the device mesh to execute the model.
        memory_budget(optional): the max cuda memory that could be used; -1.0 means infinity.
        overlap(optional): whether to overlap gradient communication with backward computing.
        solver_preference(optional): which parallelism algorithm has higher priority;
            one of 'standard', 'tp', or 'dp'.
        dataloader_option(optional): which kind of data_loader will be used;
            one of 'replicated' or 'distributed'.
        shard_option(optional): how many axes will be used to shard the model;
            one of 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
        save_solver_solution(optional): if True, save the solution to ``solution_path``.
        load_solver_solution(optional): if True, load the solution from ``solution_path``.
        solution_path(optional): the path to save or load the solution.
        return_solution(optional): if True, also return a human-readable list of
            ``"<node name> <strategy name>"`` strings for debugging/analysis.
    '''
    # trace the model into a graph module with activation-checkpoint-aware codegen
    tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
    graph = tracer.trace(root=model, meta_args=meta_args)
    graph.set_codegen(ActivationCheckpointCodeGen())
    gm = ColoGraphModule(model, graph, model.__class__.__name__)
    shape_prop_pass(gm, *meta_args.values())
    gm.recompile()

    strategies_constructor = build_strategy_constructor(graph,
                                                        device_mesh,
                                                        solver_preference=solver_preference,
                                                        dataloader_option=dataloader_option,
                                                        shard_option=shard_option)

    # obtain the per-node strategy indices, either from disk or from the solver
    if load_solver_solution:
        solution = torch.load(solution_path)
    else:
        solution = solve_solution(gm, strategies_constructor, memory_budget)
        if save_solver_solution:
            torch.save(solution, solution_path)

    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor,
                                                         overlap)
    model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)

    if not return_solution:
        return model_to_return
    nodes = [sv.node for sv in strategies_constructor.leaf_strategies]
    readable_solution = [
        f'{node.name} {node.strategies_vector[solution[index]].name}' for index, node in enumerate(nodes)
    ]
    return model_to_return, readable_solution
def autoparallelize(model: nn.Module,
                    meta_args: Dict[str, torch.Tensor] = None,
                    data_loader: torch.utils.data.DataLoader = None,
                    data_process_func: callable = None,
                    alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
                    logical_mesh_shape: Tuple[int] = None,
                    logical_mesh_id: torch.Tensor = None,
                    solver_preference: str = 'standard',
                    dataloader_option: str = 'replicated',
                    shard_option: str = 'standard',
                    save_solver_solution: bool = False,
                    load_solver_solution: bool = False,
                    solver_solution_path: str = None,
                    return_solution: bool = False,
                    memory_budget: float = -1.0):
    '''
    Initialize the device mesh, extract the meta_args, and use them to create
    a sharded model.

    Args:
        model: the model to be sharded.
        meta_args(optional): specifies the input shapes of the model; if None,
            it is extracted from ``data_loader`` via ``data_process_func``.
        data_loader(optional): the data_loader to be used in the normal training loop.
        data_process_func(optional): processes the data from the data_loader.
        alpha_beta_dict(optional): alpha and beta values per device pair; if None,
            they are profiled automatically.
        logical_mesh_shape(optional): specifies the logical mesh shape; if None,
            the best logical mesh shape is searched for.
        logical_mesh_id(optional): specifies the logical mesh id.
        solver_preference(optional): which parallelism algorithm has higher priority;
            one of 'standard', 'tp', or 'dp'.
        dataloader_option(optional): 'replicated' or 'distributed'.
        shard_option(optional): 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
        save_solver_solution(optional): if True, save the solution to ``solver_solution_path``.
        load_solver_solution(optional): if True, load the solution from ``solver_solution_path``.
        solver_solution_path(optional): the path to save or load the solution.
        return_solution(optional): if True, a readable solution is returned as well.
        memory_budget(optional): the max cuda memory that could be used; -1.0 means infinity.
    '''
    device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
                                         logical_mesh_shape=logical_mesh_shape,
                                         logical_mesh_id=logical_mesh_id)
    if meta_args is None:
        meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
    # initialize_model already returns either `model` or `(model, solution)`
    # depending on `return_solution`, so its result is forwarded directly
    return initialize_model(model,
                            meta_args,
                            device_mesh,
                            solver_preference=solver_preference,
                            dataloader_option=dataloader_option,
                            shard_option=shard_option,
                            save_solver_solution=save_solver_solution,
                            load_solver_solution=load_solver_solution,
                            solution_path=solver_solution_path,
                            return_solution=return_solution,
                            memory_budget=memory_budget)
from .addmm_handler import ADDMMFunctionHandler
from .batch_norm_handler import BatchNormModuleHandler
from .binary_elementwise_handler import BinaryElementwiseHandler
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .default_reshape_handler import DefaultReshapeHandler
from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
from .getattr_handler import GetattrHandler
from .getitem_handler import GetItemHandler
from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .output_handler import OutputHandler
from .permute_handler import PermuteHandler
from .placeholder_handler import PlaceholderHandler
from .registry import operator_registry
from .softmax_handler import SoftmaxHandler
from .split_handler import SplitHandler
from .sum_handler import SumHandler
from .tensor_constructor_handler import TensorConstructorHandler
from .transpose_handler import TransposeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .view_handler import ViewHandler
from .where_handler import WhereHandler
# Public re-exports of the node handler classes and the operator registry.
__all__ = [
    'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
    'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
    'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
    'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
    'SplitHandler'
]
from typing import Dict, List, Union
import torch
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
__all__ = ['ADDMMFunctionHandler']
@operator_registry.register(torch.addmm)
@operator_registry.register(torch.Tensor.addmm)
class ADDMMFunctionHandler(NodeHandler):
    """
    A NodeHandler which deals with the addmm operation in PyTorch
    (``torch.addmm`` / ``torch.Tensor.addmm``), i.e. ``bias + input @ other``
    where ``bias`` is ``node.args[0]``, ``input`` is ``node.args[1]`` and
    ``other`` is ``node.args[2]``.
    """

    def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType:
        # an operand that is an nn.Parameter is treated as PARAM, anything else as ARG
        if isinstance(tensor, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG
        return data_type

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """Build the {input, other, output, bias} OperationData mapping for this node."""
        # input operand
        input_data = self.node.args[1]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[1]),
                                               type=self._infer_op_data_type(input_data),
                                               data=input_data)
        # other operand
        other_data = self.node.args[2]._meta_data
        physical_other_operand = OperationData(name=str(self.node.args[2]),
                                               type=self._infer_op_data_type(other_data),
                                               data=other_data)
        # the bias may be broadcast, so its logical shape is the output shape
        bias_logical_shape = self.node._meta_data.shape
        bias_data = self.node.args[0]._meta_data
        physical_bias_operand = OperationData(name=str(self.node.args[0]),
                                              type=self._infer_op_data_type(bias_data),
                                              data=bias_data,
                                              logical_shape=bias_logical_shape)
        # output
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
        mapping = {
            "input": physical_input_operand,
            "other": physical_other_operand,
            "output": physical_output,
            'bias': physical_bias_operand
        }
        return mapping

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        """addmm is handled as a linear projection for strategy generation."""
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(
            LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm'))
        return generators

    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        # convert bias from its logical sharding spec to its physical sharding spec
        op_data_mapping = self.get_operation_data_mapping()
        bias_op_data = op_data_mapping['bias']
        bias_physical_shape = bias_op_data.data.shape
        bias_logical_shape = bias_op_data.logical_shape
        bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
        bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
            bias_sharding_spec, bias_logical_shape, bias_physical_shape)
        strategy.sharding_specs[bias_op_data] = bias_sharding_spec

        # if broadcast dims were removed, the partial results along those dims
        # must be combined by an explicit communication action
        if len(removed_dims) > 0:
            comm_action = comm_actions_for_oprands(node=self.node,
                                                   removed_dims=removed_dims,
                                                   op_data=bias_op_data,
                                                   sharding_spec=bias_sharding_spec)
            strategy.communication_actions[bias_op_data] = comm_action

        return strategy
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import MetaInfoModuleHandler, ModuleHandler
from .registry import operator_registry
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
__all__ = ['BatchNormModuleHandler']
@operator_registry.register(torch.nn.BatchNorm1d)
@operator_registry.register(torch.nn.BatchNorm2d)
@operator_registry.register(torch.nn.BatchNorm3d)
class BatchNormModuleHandler(MetaInfoModuleHandler):
    """
    Node handler producing sharding strategies for ``nn.BatchNormXd`` modules.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        data_mapping = self.get_operation_data_mapping()
        return [BatchNormStrategyGenerator(data_mapping, self.device_mesh)]

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """Build the OperationData mapping for input, weight, output and the
        batch-norm running statistics buffers (plus bias when present)."""

        def _buffer_op_data(buffer_name):
            # wrap a named buffer as a BUFFER-typed OperationData
            buffer = self.named_buffers[buffer_name]
            return OperationData(name=buffer_name,
                                 type=OperationDataType.BUFFER,
                                 data=buffer,
                                 logical_shape=buffer.shape)

        weight = self.named_parameters['weight']
        mapping = {
            "input":
                OperationData(name=str(self.node.args[0]),
                              type=OperationDataType.ARG,
                              data=self.node.args[0]._meta_data),
            "other":
                OperationData(name="weight",
                              type=OperationDataType.PARAM,
                              data=weight,
                              logical_shape=weight.shape),
            "output":
                OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
            "running_mean": _buffer_op_data('running_mean'),
            "running_var": _buffer_op_data('running_var'),
            "num_batches_tracked": _buffer_op_data('num_batches_tracked'),
        }

        if self.named_parameters['bias'] is not None:
            mapping['bias'] = OperationData(name="bias",
                                            type=OperationDataType.PARAM,
                                            data=self.named_parameters['bias'])
        return mapping
from typing import Dict, List, Union
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..constants import BCAST_FUNC_OP
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
__all__ = ['BinaryElementwiseHandler']
@operator_registry.register(BCAST_FUNC_OP)
class BinaryElementwiseHandler(MetaInfoNodeHandler):
    """
    A node handler which deals with binary operations whose two operands may be
    broadcast together, such as ``torch.add``.
    """

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """Build the {input, other, output} OperationData mapping, using the
        broadcast output shape as the shared logical shape and recording
        non-tensor operands in ``self.non_tensor_list``."""
        # the broadcast result shape is used as the logical shape of all operands
        bcast_shape = self.node._meta_data.shape

        def _get_op_data_type(tensor):
            if isinstance(tensor, torch.nn.parameter.Parameter):
                return OperationDataType.PARAM
            else:
                return OperationDataType.ARG

        def _get_arg_value(idx):
            # returns (meta_data, non_tensor) where non_tensor marks scalar operands
            non_tensor = False
            if isinstance(self.node.args[idx], Node):
                meta_data = self.node.args[idx]._meta_data
                # The meta_data of node type argument could also possibly be a non-tensor object.
                if not isinstance(meta_data, torch.Tensor):
                    assert isinstance(meta_data, (int, float))
                    meta_data = torch.Tensor([meta_data]).to('meta')
                    non_tensor = True

            else:
                # this is in fact a real data like int 1
                # but we can deem it as meta data
                # as it won't affect the strategy generation
                assert isinstance(self.node.args[idx], (int, float))
                meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
                non_tensor = True

            return meta_data, non_tensor

        input_meta_data, non_tensor_input = _get_arg_value(0)
        other_meta_data, non_tensor_other = _get_arg_value(1)
        output_meta_data = self.node._meta_data

        # we need record op_data with non-tensor data in this list,
        # and filter the non-tensor op_data in post_process.
        self.non_tensor_list = []
        input_op_data = OperationData(name=str(self.node.args[0]),
                                      type=_get_op_data_type(input_meta_data),
                                      data=input_meta_data,
                                      logical_shape=bcast_shape)
        other_op_data = OperationData(name=str(self.node.args[1]),
                                      type=_get_op_data_type(other_meta_data),
                                      data=other_meta_data,
                                      logical_shape=bcast_shape)
        output_op_data = OperationData(name=str(self.node),
                                       type=OperationDataType.OUTPUT,
                                       data=output_meta_data,
                                       logical_shape=bcast_shape)

        if non_tensor_input:
            self.non_tensor_list.append(input_op_data)
        if non_tensor_other:
            self.non_tensor_list.append(other_op_data)
        mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
        return mapping

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(BinaryElementwiseStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        """Drop sharding specs of non-tensor operands and convert the logical
        (broadcast) sharding specs back to the operands' physical shapes."""
        # convert bias from its logical sharding spec to its physical sharding spec
        op_data_mapping = self.get_operation_data_mapping()

        for op_name, op_data in op_data_mapping.items():
            if op_data in self.non_tensor_list:
                # remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
                strategy.sharding_specs.pop(op_data)

            else:
                # convert the logical sharding spec to physical sharding spec if broadcast
                # e.g. torch.rand(4, 4) + torch.rand(4)
                physical_shape = op_data.data.shape
                logical_shape = op_data.logical_shape
                sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
                sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
                    sharding_spec, logical_shape, physical_shape)

                strategy.sharding_specs[op_data] = sharding_spec
                # dims removed by broadcasting need an explicit communication action
                if len(removed_dims) > 0:
                    comm_action = comm_actions_for_oprands(node=self.node,
                                                           removed_dims=removed_dims,
                                                           op_data=op_data,
                                                           sharding_spec=sharding_spec)
                    strategy.communication_actions[op_data] = comm_action

        return strategy
from typing import Dict, List, Union
import torch
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler']
def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
    """
    Build the operation-data mapping shared by the `bmm` and `addbmm` node
    handlers, so the extraction logic lives in one place.
    """

    def _arg_operand(idx):
        # positional tensor argument of the call node
        return OperationData(name=str(node.args[idx]),
                             type=OperationDataType.ARG,
                             data=node.args[idx]._meta_data)

    mapping = {
        "input": _arg_operand(input_idx),
        "other": _arg_operand(other_idx),
        "output": OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data),
    }
    if bias_idx is not None:
        # the bias is broadcast to the output shape, so the output shape
        # serves as its logical shape
        mapping['bias'] = OperationData(name=str(node.args[bias_idx]),
                                        type=OperationDataType.ARG,
                                        data=node.args[bias_idx]._meta_data,
                                        logical_shape=node._meta_data.shape)
    return mapping
@operator_registry.register(torch.bmm)
@operator_registry.register(torch.Tensor.bmm)
class BMMFunctionHandler(NodeHandler):
    """
    NodeHandler for batched matrix multiplication (`torch.bmm` and
    `torch.Tensor.bmm`). These ops require strictly 3D tensors, so no
    logical-physical shape conversion is needed in this handler.
    """

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        return _get_data_mapping_for_bmm_op(node=self.node, input_idx=0, other_idx=1)

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        mapping = self.get_operation_data_mapping()
        return [BatchedMatMulStrategyGenerator(mapping, self.device_mesh)]
@operator_registry.register(torch.addbmm)
@operator_registry.register(torch.Tensor.addbmm)
class AddBMMFunctionHandler(NodeHandler):
    """
    NodeHandler for `torch.addbmm` / `torch.Tensor.addbmm` (addition plus
    batched matmul). The two matmul operands must be 3D; because addbmm
    reduces the batch dimension, the bias is at most 2D and therefore needs
    logical-physical shape conversion.
    """

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # addbmm(bias, batch1, batch2): the bias is args[0]
        return _get_data_mapping_for_bmm_op(node=self.node, input_idx=1, other_idx=2, bias_idx=0)

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        mapping = self.get_operation_data_mapping()
        generator = BatchedMatMulStrategyGenerator(mapping, self.device_mesh)
        # addbmm will shrink the first batch dim
        generator.squeeze_batch_dim = True
        return [generator]

    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        """Convert the bias sharding spec from its logical (broadcast) shape
        back to its physical shape."""
        mapping = self.get_operation_data_mapping()
        if 'bias' in mapping:
            bias = mapping['bias']
            spec = strategy.get_sharding_spec_by_name(bias.name)
            spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
                spec, bias.logical_shape, bias.data.shape)
            strategy.sharding_specs[bias] = spec
            if len(removed_dims) > 0:
                strategy.communication_actions[bias] = comm_actions_for_oprands(node=self.node,
                                                                                removed_dims=removed_dims,
                                                                                op_data=bias,
                                                                                sharding_spec=spec)
        return strategy
from typing import Dict, List
import torch
import torch.nn.functional as F
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from ..utils import transpose_partition_dim
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator
__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
@operator_registry.register(torch.nn.Conv1d)
@operator_registry.register(torch.nn.Conv2d)
@operator_registry.register(torch.nn.Conv3d)
class ConvModuleHandler(MetaInfoModuleHandler):
    """
    A handler producing sharding strategies for nn.ConvXd modules.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        mapping = self.get_operation_data_mapping()
        return [ConvStrategyGenerator(mapping, self.device_mesh)]

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # strategies are generated against the weight with its first two dims
        # swapped; post_process transposes the spec back to the real layout
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        weight = self.named_parameters['weight']
        transposed_weight_shape = list(weight.shape)
        transposed_weight_shape[0], transposed_weight_shape[1] = \
            transposed_weight_shape[1], transposed_weight_shape[0]
        physical_other_operand = OperationData(name="weight",
                                               type=OperationDataType.PARAM,
                                               data=weight,
                                               logical_shape=torch.Size(transposed_weight_shape))
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
        if "bias" in self.named_parameters:
            mapping['bias'] = OperationData(name="bias",
                                            type=OperationDataType.PARAM,
                                            data=self.named_parameters['bias'])
        return mapping

    def post_process(self, strategy: ShardingStrategy):
        """
        Convert the sharding spec of the weight parameter back to its original shape.
        """
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == "weight":
                transpose_partition_dim(sharding_spec, 0, 1)
        return strategy
@operator_registry.register(F.conv1d)
@operator_registry.register(F.conv2d)
@operator_registry.register(F.conv3d)
class ConvFunctionHandler(MetaInfoNodeHandler):
    """
    A handler producing sharding strategies for nn.functional.convXd functions.
    """

    @staticmethod
    def _operand_type(meta_data):
        # an operand backed by a Parameter is sharded as PARAM, otherwise ARG
        if isinstance(meta_data, torch.nn.parameter.Parameter):
            return OperationDataType.PARAM
        return OperationDataType.ARG

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        mapping = self.get_operation_data_mapping()
        return [ConvStrategyGenerator(mapping, self.device_mesh)]

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # strategies are generated against the weight with its first two dims
        # swapped; post_process transposes the spec back to the real layout
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        weight_meta = self.node.args[1]._meta_data
        transposed_weight_shape = list(weight_meta.shape)
        transposed_weight_shape[0], transposed_weight_shape[1] = \
            transposed_weight_shape[1], transposed_weight_shape[0]
        physical_other_operand = OperationData(name=str(self.node.args[1]),
                                               type=self._operand_type(weight_meta),
                                               data=weight_meta,
                                               logical_shape=torch.Size(transposed_weight_shape))
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
        if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None:
            bias_meta = self.node.kwargs["bias"]._meta_data
            mapping['bias'] = OperationData(name=str(self.node.kwargs["bias"]),
                                            type=self._operand_type(bias_meta),
                                            data=bias_meta)
        return mapping

    def post_process(self, strategy: ShardingStrategy):
        """
        Convert the sharding spec of the weight parameter back to its original shape.
        """
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == str(self.node.args[1]):
                transpose_partition_dim(sharding_spec, 0, 1)
        return strategy
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import DefaultReshapeGenerator, StrategyGenerator
__all__ = ['DefaultReshapeHandler']
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class DefaultReshapeHandler(MetaInfoNodeHandler):
    """
    A handler producing sharding strategies for reshape-like ops such as torch.flatten.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        mapping = self.get_operation_data_mapping()
        return [DefaultReshapeGenerator(mapping, self.device_mesh, self.node.args[0])]

    def infer_logical_shape(self, data):
        """
        Infer the logical shape of an operand whose data is either a tensor
        or a tuple of tensors (e.g. the output of an op returning several tensors).
        """
        if isinstance(data, torch.Tensor):
            return data.shape
        assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
        shapes = []
        for item in data:
            assert isinstance(item, torch.Tensor), "input_data should be a tuple of tensor or a tensor."
            shapes.append(item.shape)
        return tuple(shapes)

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # the reshaped operand may itself be a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG
        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=data_type,
                                               data=input_data,
                                               logical_shape=self.infer_logical_shape(input_data))
        output_data = self.node._meta_data
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=output_data,
                                        logical_shape=self.infer_logical_shape(output_data))
        return {"input": physical_input_operand, "output": physical_output}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment