"vscode:/vscode.git/clone" did not exist on "cd5961b5dad560d63f4dd42d08d6ee3877b82003"
Commit 144fd688 authored by zhaoying1's avatar zhaoying1
Browse files

Added bitsandbytes

parent 387082e1
Pipeline #328 canceled with stages
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
class LAMB(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
adam_w_mode=True,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=False,
max_unorm=1.0,
):
super(LAMB, self).__init__(
"lamb",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
max_unorm=1.0,
)
class LAMB8bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
adam_w_mode=True,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=False,
max_unorm=1.0,
):
super(LAMB8bit, self).__init__(
"lamb",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
max_unorm=1.0,
)
class LAMB32bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
adam_w_mode=True,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=False,
max_unorm=1.0,
):
super(LAMB32bit, self).__init__(
"lamb",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
max_unorm=1.0,
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.optim import Optimizer
from bitsandbytes.optim.optimizer import Optimizer1State
class LARS(Optimizer1State):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
max_unorm=0.02,
):
if momentum == 0:
raise NotImplementedError(
f"LARS without momentum is not supported!"
)
super(LARS, self).__init__(
"lars",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
max_unorm=max_unorm,
block_wise=False,
)
class LARS8bit(Optimizer1State):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
max_unorm=0.02,
):
if momentum == 0:
raise NotImplementedError(
f"LARS without momentum is not supported!"
)
super(LARS8bit, self).__init__(
"lars",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
max_unorm=max_unorm,
block_wise=False,
)
class LARS32bit(Optimizer1State):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
max_unorm=0.02,
):
if momentum == 0:
raise NotImplementedError(
f"LARS without momentum is not supported!"
)
super(LARS32bit, self).__init__(
"lars",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
max_unorm=max_unorm,
block_wise=False,
)
class PytorchLARS(Optimizer):
def __init__(
self,
params,
lr=0.01,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
max_unorm=0.02,
):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
)
defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
max_unorm=max_unorm,
)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError(
"Nesterov momentum requires a momentum and zero dampening"
)
super(PytorchLARS, self).__init__(params, defaults)
def __setstate__(self, state):
super(PytorchLARS, self).__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
d_p_list = []
momentum_buffer_list = []
weight_decay = group["weight_decay"]
momentum = group["momentum"]
dampening = group["dampening"]
nesterov = group["nesterov"]
max_unorm = group["max_unorm"]
lr = group["lr"]
for p in group["params"]:
if p.grad is None:
continue
state = self.state[p]
d_p = p.grad
if weight_decay != 0:
d_p = d_p.add(param, alpha=weight_decay)
if momentum != 0:
buf = state.get("momentum_buffer", None)
if buf is None:
buf = torch.clone(d_p).detach()
state["momentum_buffer"] = buf
else:
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
update = d_p + buf * momentum
else:
update = buf
update_scale = 1.0
if max_unorm > 0.0:
assert p.dtype == torch.float32
pnorm = torch.norm(p.detach())
unorm = torch.norm(update)
if unorm > max_unorm * pnorm:
update_scale = max_unorm * pnorm / unorm
p.add_(update, alpha=-lr * update_scale)
return loss
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import abc as container_abcs
from collections import defaultdict
from copy import deepcopy
from itertools import chain
import torch
import bitsandbytes.functional as F
class MockArgs(object):
def __init__(self, initial_data):
for key in initial_data:
setattr(self, key, initial_data[key])
class GlobalOptimManager(object):
_instance = None
def __init__(self):
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.pid2config = {}
self.index2config = {}
self.optimizer = None
self.uses_config_override = False
self.module_weight_config_triple = []
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls.__new__(cls)
cls._instance.initialize()
return cls._instance
def register_parameters(self, params):
param_groups = list(params)
if not isinstance(param_groups[0], dict):
param_groups = [{"params": param_groups}]
for group_index, group in enumerate(param_groups):
for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
self.index2config[(group_index, p_index)] = self.pid2config[
id(p)
]
def override_config(
self, parameters, key=None, value=None, key_value_dict=None
):
"""
Overrides initial optimizer config for specific parameters.
The key-values of the optimizer config for the input parameters are overidden
This can be both, optimizer parameters like "betas", or "lr" or it can be
8-bit specific paramters like "optim_bits", "percentile_clipping".
Parameters
----------
parameters : torch.Tensor or list(torch.Tensors)
The input parameters.
key : str
The hyperparamter to override.
value : object
The value for the hyperparamters.
key_value_dict : dict
A dictionary with multiple key-values to override.
"""
self.uses_config_override = True
if isinstance(parameters, torch.nn.Parameter):
parameters = [parameters]
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
if key is not None and value is not None:
assert key_value_dict is None
key_value_dict = {key: value}
if key_value_dict is not None:
for p in parameters:
if id(p) in self.pid2config:
self.pid2config[id(p)].update(key_value_dict)
else:
self.pid2config[id(p)] = key_value_dict
def register_module_override(self, module, param_name, config):
self.module_weight_config_triple.append((module, param_name, config))
class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults)
self.initialized = False
self.name2qmap = {}
self.mng = GlobalOptimManager.get_instance()
self.non_castable_tensor_keys = set(
[
"qmap1",
"qmap2",
"max1",
"max2",
"new_max1",
"new_max2",
"state1",
"state2",
"gnorm_vec",
"absmax1",
"absmax2",
"unorm_vec",
]
)
if optim_bits == 8:
self.fill_qmap()
def fill_qmap(self):
self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
def __setstate__(self, state):
super(Optimizer8bit, self).__setstate__(state)
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups):
raise ValueError(
"loaded state dict has a different number of "
"parameter groups"
)
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError(
"loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
)
# Update the state
id_map = {
old_id: p
for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)),
chain.from_iterable((g["params"] for g in groups)),
)
}
def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
if param.is_floating_point() and value.dtype != torch.uint8:
value = value.to(param.dtype)
return value
elif isinstance(value, dict):
for k, v in value.items():
if k in self.non_castable_tensor_keys:
value[k] = v.to(param.device)
else:
value[k] = cast(param, v)
return value
elif isinstance(value, container_abcs.Iterable):
return type(value)(cast(param, v) for v in value)
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
else:
state[k] = v
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
new_group["params"] = group["params"]
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)
]
self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p in self.state:
values = self.state[p]
for k, v in values.items():
if isinstance(v, torch.Tensor):
self.state[p][k] = v.to(p.device)
def check_overrides(self):
for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr)
assert pmodule is not None
assert isinstance(pmodule, torch.Tensor) or isinstance(
pmodule, torch.Parameter
)
found = False
for gindex, group in enumerate(self.param_groups):
if found:
break
for pindex, p in enumerate(group["params"]):
if found:
break
if id(p) == id(pmodule):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
self.mng.index2config[
(gindex, pindex)
] = self.mng.pid2config[id(p)]
found = True
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
overflows = []
if not self.initialized:
self.check_overrides()
self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
if len(state) == 0:
self.init_state(group, p, gindex, pindex)
self.update_step(group, p, gindex, pindex)
return loss
def get_config(self, gindex, pindex, group):
config = {}
config["betas"] = group["betas"]
config["eps"] = group["eps"]
config["weight_decay"] = group["weight_decay"]
config["lr"] = group["lr"]
config["optim_bits"] = self.args.optim_bits
config["min_8bit_size"] = self.args.min_8bit_size
config["percentile_clipping"] = self.args.percentile_clipping
config["block_wise"] = self.args.block_wise
config["max_unorm"] = self.args.max_unorm
config["skip_zeros"] = self.args.skip_zeros
if (gindex, pindex) in self.mng.index2config:
config.update(self.mng.index2config[(gindex, pindex)])
return config
def init_state(self, group, p, gindex, pindex):
raise NotImplementedError(f"init_state method needs to be overidden")
def update_step(self, group, p, gindex, pindex):
raise NotImplementedError(
f"The update_step method needs to be overidden"
)
class Optimizer2State(Optimizer8bit):
def __init__(
self,
optimizer_name,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if isinstance(betas, str):
# format: '(beta1, beta2)'
betas = betas.replace("(", "").replace(")", "").strip().split(",")
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(
f"Invalid beta parameter at index {i}: {betas[i]}"
)
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
args["optim_bits"] = optim_bits
args["percentile_clipping"] = 100
args["min_8bit_size"] = min_8bit_size
args["percentile_clipping"] = percentile_clipping
args["block_wise"] = block_wise
args["max_unorm"] = max_unorm
args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
self.args = args
self.optimizer_name = optimizer_name
@torch.no_grad()
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
if config["optim_bits"] == 32:
dtype = torch.float32
elif config["optim_bits"] == 8:
dtype = torch.uint8
else:
raise NotImplementedError(
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
state = self.state[p]
state["step"] = 0
if dtype == torch.float32 or (
dtype == torch.uint8 and p.numel() < 4096
):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
state["state2"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
elif dtype == torch.uint8:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
p.device
)
self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
p.device
)
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap1"] = self.name2qmap["dynamic"]
state["state2"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap2"] = self.name2qmap["udynamic"]
if config["block_wise"]:
n = p.numel()
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
state["absmax2"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
else:
state["max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["new_max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device)
if config["max_unorm"] > 0.0:
state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
state = self.state[p]
grad = p.grad
config = self.get_config(gindex, pindex, group)
state["step"] += 1
step = state["step"]
if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
)
else:
gnorm_scale = 1.0
if state["state1"].dtype == torch.float:
F.optimizer_update_32bit(
self.optimizer_name,
grad,
p,
state["state1"],
config["betas"][0],
config["eps"],
step,
config["lr"],
state["state2"],
config["betas"][1],
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
skip_zeros=config["skip_zeros"],
)
elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
F.optimizer_update_8bit(
self.optimizer_name,
grad,
p,
state["state1"],
state["state2"],
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
state["qmap2"],
state["max1"],
state["max2"],
state["new_max1"],
state["new_max2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
unorm_vec=state["unorm_vec"]
if config["max_unorm"] > 0.0
else None,
max_unorm=config["max_unorm"],
)
# swap maxes
state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
state["max2"], state["new_max2"] = state["new_max2"], state["max2"]
elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
F.optimizer_update_8bit_blockwise(
self.optimizer_name,
grad,
p,
state["state1"],
state["state2"],
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
state["qmap2"],
state["absmax1"],
state["absmax2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
skip_zeros=config["skip_zeros"],
)
class Optimizer1State(Optimizer8bit):
def __init__(
self,
optimizer_name,
params,
lr=1e-3,
betas=(0.9, 0.0),
eps=1e-8,
weight_decay=0.0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(
f"Invalid beta parameter at index {i}: {betas[i]}"
)
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
args["optim_bits"] = optim_bits
args["percentile_clipping"] = 100
args["min_8bit_size"] = min_8bit_size
args["percentile_clipping"] = percentile_clipping
args["block_wise"] = block_wise
args["max_unorm"] = max_unorm
args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
self.args = args
self.optimizer_name = optimizer_name
@torch.no_grad()
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
if config["optim_bits"] == 32:
dtype = torch.float32
elif config["optim_bits"] == 8:
dtype = torch.uint8
else:
raise NotImplementedError(
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
state = self.state[p]
state["step"] = 0
if dtype == torch.float32 or (
dtype == torch.uint8 and p.numel() < 4096
):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
elif dtype == torch.uint8:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
p.device
)
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap1"] = self.name2qmap["dynamic"]
if config["block_wise"]:
n = p.numel()
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
else:
state["max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device)
if config["max_unorm"] > 0.0:
state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
state = self.state[p]
grad = p.grad
config = self.get_config(gindex, pindex, group)
state["step"] += 1
step = state["step"]
if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
)
else:
gnorm_scale = 1.0
if state["state1"].dtype == torch.float:
F.optimizer_update_32bit(
self.optimizer_name,
grad,
p,
state["state1"],
config["betas"][0],
config["eps"],
step,
config["lr"],
None,
0.0,
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
skip_zeros=config["skip_zeros"],
)
elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
F.optimizer_update_8bit(
self.optimizer_name,
grad,
p,
state["state1"],
None,
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
None,
state["max1"],
None,
state["new_max1"],
None,
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
)
state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
F.optimizer_update_8bit_blockwise(
self.optimizer_name,
grad,
p,
state["state1"],
None,
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
None,
state["absmax1"],
None,
config["weight_decay"],
gnorm_scale=gnorm_scale,
skip_zeros=config["skip_zeros"],
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
class RMSprop(Optimizer1State):
def __init__(
self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
weight_decay=0,
momentum=0,
centered=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0:
raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop, self).__init__(
"rmsprop",
params,
lr,
(alpha, momentum),
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class RMSprop8bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
weight_decay=0,
momentum=0,
centered=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0:
raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop8bit, self).__init__(
"rmsprop",
params,
lr,
(alpha, momentum),
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class RMSprop32bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
weight_decay=0,
momentum=0,
centered=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0:
raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop32bit, self).__init__(
"rmsprop",
params,
lr,
(alpha, momentum),
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
class SGD(Optimizer1State):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD, self).__init__(
"momentum",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class SGD8bit(Optimizer1State):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD8bit, self).__init__(
"momentum",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class SGD32bit(Optimizer1State):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD32bit, self).__init__(
"momentum",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
import shlex
import subprocess
from typing import Tuple
def execute_and_return(command_string: str) -> Tuple[str, str]:
def _decode(subprocess_err_out_tuple):
return tuple(
to_decode.decode("UTF-8").strip()
for to_decode in subprocess_err_out_tuple
)
def execute_and_return_decoded_std_streams(command_string):
return _decode(
subprocess.Popen(
shlex.split(command_string),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
).communicate()
)
std_out, std_err = execute_and_return_decoded_std_streams(command_string)
return std_out, std_err
# Compiling from source
Basic steps.
1. `make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cpuonly`
2. `CUDA_VERSION=XXX python setup.py install`
To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive).
For your convenience, there is an installation script in the root directory that installs CUDA 11.1 locally and configures it automatically. After installing you should add the `bin` sub-directory to the `$PATH` variable to make the compiler visible to your system. To do this you can add this to your `.bashrc` by executing these commands:
```bash
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64/" >> ~/.bashrc
echo "export PATH=$PATH:/usr/local/cuda/bin/" >> ~/.bashrc
source ~/.bashrc
```
By default, the Makefile will look at your `CUDA_HOME` environmental variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler.
Either `nvcc` needs to be in path for the `CUDA_HOME` variable needs to be set to the CUDA directory root (e.g. `/usr/local/cuda`) in order for compilation to succeed
If you have problems compiling the library with these instructions from source, please open an issue.
#include <common.h>
#include <float.h>
void *quantize_block(void *arguments) {
// 1. find absmax in block
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
// 4. check minimal distance
// 5. store index
struct quantize_block_args *args = (quantize_block_args *) arguments;
// 1. find absmax in block
float absmax_block = -FLT_MAX;
for (long long i = args->block_idx; i < args->block_end; i++)
absmax_block = fmax(absmax_block, fabs(args->A[i]));
args->absmax[args->block_idx / args->blocksize] = absmax_block;
for (long long i = args->block_idx; i < args->block_end; i++) {
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
float normed_value = args->A[i] / absmax_block;
long long idx = args->bin_searcher->scalar(normed_value);
// 4. check minimal distance
// The binary search returns always the value to the left, which might not be the closest value
if (idx < 255) {
float dist_left = fabs(normed_value - (args->code[idx]));
float dist_right = fabs(normed_value - (args->code[idx + 1]));
if (dist_right < dist_left) { idx += 1; }
}
// 5. store index
args->out[i] = (unsigned char) idx;
}
return NULL;
}
#include <BinSearch.h>
#ifndef common
#define common
using namespace BinSearch;
#define BLOCK_SIZE 16384
struct quantize_block_args {
BinAlgo<Scalar, float, Direct2> *bin_searcher;
float *code;
float *A;
float *absmax;
unsigned char *out;
long long block_end;
long long block_idx;
long long threadidx;
long long blocksize;
};
void *quantize_block(void *arguments);
#endif
#include <BinSearch.h>
#include <pthread.h>
#include <common.h>
using namespace BinSearch;
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) {
for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
long long block_end = block_idx + valid_items;
for (long long i = block_idx; i < block_end; i++)
out[i] = code[A[i]] * absmax[block_idx / blocksize];
}
}
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n)
{
// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
code[0] = -1.0f;
long long num_blocks = n / blocksize;
num_blocks += n % blocksize == 0 ? 0 : 1;
const uint32 elements_code = 256;
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
int thread_wave_size = 256;
// we chunk the thresds into waves of 256 since the max limit is
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
{
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
for(long long i = 0; i < valid_chunks; i++)
args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
int chunks_processed = 0;
for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
{
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
long long block_end = block_idx + valid_items;
struct quantize_block_args *arg = args[chunks_processed];
arg->bin_searcher = &bin_searcher;
arg->code = code;
arg->A = A;
arg->absmax = absmax;
arg->out = out;
arg->block_end = block_end;
arg->block_idx = block_idx;
arg->threadidx = block_idx / blocksize;
arg->blocksize = blocksize;
pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
chunks_processed += 1;
if(chunks_processed == valid_chunks){ break; }
}
for (int i = 0; i < valid_chunks; i++)
int err = pthread_join(threads[i], NULL);
free(threads);
for (int i = 0; i < valid_chunks; i++)
free(args[i]);
free(args);
}
}
#ifndef BITSANDBYTES_CPU_OPS_H
#define BITSANDBYTES_CPU_OPS_H
#include <iostream>
#include <stdio.h>
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n);
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n);
#endif
\ No newline at end of file
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <hip/hip_runtime.h>
#include <kernels.cuh>
#include <hipcub/hipcub.hpp>
#include <hipcub/block/block_radix_sort.hpp>
#include <hipcub/warp/warp_reduce.hpp>
#include <hipcub/block/block_load.hpp>
#include <hipcub/block/block_discontinuity.hpp>
#include <hipcub/block/block_store.hpp>
#include <hipcub/block/block_reduce.hpp>
#include <hip/hsa_detail/hip_math_constants.h>
#define HLF_MAX 65504
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
// __device__ float atomicMax(float* address, float val) {
// int* address_as_i = reinterpret_cast<int*>(address);
// int old = *address_as_i, assumed;
// do {
// assumed = old;
// old = atomicCAS(
// reinterpret_cast<int*>(address), assumed,
// __float_as_int(fmaxf(val, __int_as_float(assumed))));
// } while (assumed != old);
// return __int_as_float(old);
// }
// __device__ float atomicMin(float* address, float val) {
// int* address_as_i = reinterpret_cast<int*>(address);
// int old = *address_as_i, assumed;
// do {
// assumed = old;
// old = atomicCAS(
// reinterpret_cast<int*>(address), assumed,
// __float_as_int(fminf(val, __int_as_float(assumed))));
// } while (assumed != old);
// return __int_as_float(old);
// }
template <int STOCHASTIC>
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
{
int pivot = 127;
int upper_pivot = 255;
int lower_pivot = 0;
float lower = -1.0f;
float upper = 1.0f;
float val = smem_code[pivot];
// i>>=1 = {32, 16, 8, 4, 2, 1}
for(int i = 64; i > 0; i>>=1)
{
if(x > val)
{
lower_pivot = pivot;
lower = val;
pivot+=i;
}
else
{
upper_pivot = pivot;
upper = val;
pivot-=i;
}
val = smem_code[pivot];
}
if(upper_pivot == 255)
upper = smem_code[upper_pivot];
if(lower_pivot == 0)
lower = smem_code[lower_pivot];
if(!STOCHASTIC)
{
if(x > val)
{
float midpoint = (upper+val)*0.5f;
if(x > midpoint)
{
return upper_pivot;
}
else
return pivot;
}
else
{
float midpoint = (lower+val)*0.5f;
if(x < midpoint)
return lower_pivot;
else
return pivot;
}
}
else
{
if(x > val)
{
float dist_to_upper = fabsf(upper-x);
float dist_full = upper-val;
if(rand >= dist_to_upper/dist_full) return upper_pivot;
else return pivot;
}
else
{
float dist_to_lower = fabsf(lower-x);
float dist_full = val-lower;
if(rand >= dist_to_lower/dist_full) return lower_pivot;
else return pivot;
}
}
}
template <int SIGNED>
__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x)
{
int pivot = 127;
int upper_pivot = 255;
int lower_pivot = 0;
float lower = SIGNED ? -1.0f : 0.0f;
float upper = 1.0f;
float midpoint;
float val = quadrants[1];
int local_pivot = 1;
int offset = 1;
// i>>=1 = {32, 16, 8, 4, 2, 1}
for(int i = 64; i > 0; i>>=1)
{
if(x > val)
{
lower_pivot = pivot;
lower = val;
pivot+=i;
//val = i == 64 ? quadrants[2] : smem_code[pivot];
local_pivot += offset;
}
else
{
upper_pivot = pivot;
upper = val;
pivot-=i;
//val = i == 64 ? quadrants[0] : smem_code[pivot];
local_pivot -= offset;
}
val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot];
offset -= 1;
}
if(x > val)
{
midpoint = (upper+val)*0.5f;
if(x > midpoint)
return upper_pivot;
else
return pivot;
}
else
{
midpoint = (lower+val)*0.5f;
if(x < midpoint)
return lower_pivot;
else
return pivot;
}
}
template <int SIGNED>
__device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper)
{
int lower_pivot = QUADRANT*16-1 - 0;
int pivot = QUADRANT*16-1 + 16;
int upper_pivot = QUADRANT*16-1 + 31;
float val = midpoint;
// i>>=1 = {32, 16, 8, 4, 2, 1}
for(int i = 16; i > 0; i>>=1)
{
if(x > val)
{
lower_pivot = pivot;
lower = val;
pivot+=i;
}
else
{
upper_pivot = pivot;
upper = val;
pivot-=i;
}
val = smem_code[pivot];
}
if(x > val)
{
midpoint = (upper+val)*0.5f;
if(x > midpoint)
return upper_pivot;
else
return pivot;
}
else
{
midpoint = (lower+val)*0.5f;
if(x < midpoint)
return lower_pivot;
else
return pivot;
}
}
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
{
const int tid = threadIdx.x + (blockDim.x*blockIdx.x);
const int numThreads = blockDim.x*gridDim.x;
for(int i = tid; i < n; i+=numThreads)
{
int idx = (index1[i]*maxidx1) + index2[i];
atomicAdd(&histogram[idx], src[i]);
}
}
template<typename T, int BLOCK_SIZE, int NUM_MAX>
__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n)
{
typedef hipcub::WarpReduce<T> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/8 , 8, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
__shared__ typename LoadT::TempStorage loadt;
const int warp_idx = threadIdx.x/32;
const int valid_items = n - (blockIdx.x*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (blockIdx.x*BLOCK_SIZE);
// BLOCK_SIZE/32 == number of warps
__shared__ int smem_max_indices[8*BLOCK_SIZE/32];
__shared__ float smem_max_values[8*BLOCK_SIZE/32];
T values[8];
T max1 = -64000.0f;
T max2 = -64000.0f;
int max_idx1 = -1;
int max_idx2 = -1;
int sign1 = -1;
int sign2 = -1;
// 1. load 8 values per thread
// 2. compute 2-max in registers (64 max per warp)
// 3. do warp reduction + broadcast back
// 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
// 5. Repeat (3) 8 times for top 8 values in 256
// 6. store with byte index
LoadT(loadt).Load(&(A[(blockIdx.x*BLOCK_SIZE)]), values, valid_items, (T)0.0f);
#pragma unroll 8
for(int i = 0; i < 8; i++)
{
T absval = fabsf(values[i]);
if(absval > max1)
{
max1 = values[i];
sign1 = signbit(values[i]);
max_idx1 = 8*threadIdx.x + i;
}
else if(absval > max2)
{
max2 = values[i];
sign2 = signbit(values[i]);
max_idx2 = 8*threadIdx.x + i;
}
}
float warp_max;
for(int i = 0; i < 8; i++)
{
// 3. do warp reduction + broadcast back
warp_max = WarpReduce(temp_storage).Reduce(max1, hipcub::Max());
warp_max = hipcub::ShuffleIndex<32>(warp_max, 0, 0xffffffff);
// 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
if(warp_max == max1)
{
smem_max_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1;
smem_max_indices[warp_idx*8 + i] = max_idx1;
sign1 = sign2;
max1 = max2;
max_idx1 = max_idx2;
max2 = -64000.0f;
}
// __syncwarp();
__syncthreads();
}
if(threadIdx.x % 32 < 8)
{
// offset: 8 values per 256 input values
//
int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8;
}
}
#define THREADS_ESTIMATE 512
#define NUM_ESTIMATE 8
#define BLOCK_ESTIMATE 4096
template<typename T>
__launch_bounds__(THREADS_ESTIMATE, 1)
__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n)
{
const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE);
int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE;
const int base_idx = (blockIdx.x * BLOCK_ESTIMATE);
const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE));
T vals[NUM_ESTIMATE];
typedef hipcub::BlockRadixSort<T, THREADS_ESTIMATE, NUM_ESTIMATE, hipcub::NullType, 4, true, hipcub::BLOCK_SCAN_RAKING> BlockRadixSort;
typedef hipcub::BlockLoad<T, THREADS_ESTIMATE, NUM_ESTIMATE, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
__shared__ union {
typename LoadFloat::TempStorage loadf;
typename BlockRadixSort::TempStorage sort;
int smem_qidx[BLOCK_ESTIMATE];
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE)
{
valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i;
// do not process half-blocks
if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; }
#pragma unroll 4
for(int j = 0; j < NUM_ESTIMATE; j++)
vals[j] = max_val;
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items);
#pragma unroll 4
for(int j = 0; j < NUM_ESTIMATE; j++)
vals[j] = ((float)vals[j]) * reciprocal_num_blocks;
__syncthreads();
// sort into striped pattern to mitigate bank conflicts
// striped pattern index for thread 0 [0, 1024, 2048, 3096]
// striped pattern index for thread 1 [1, 1025, 2049, 3097]
BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals);
__syncthreads();
for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x)
temp_storage.smem_qidx[j] = -1;
if(threadIdx.x < 256)
{
float q_interval = (1.0f-(2.0f*offset))/255.0f;
int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1)));
temp_storage.smem_qidx[local_idx] = threadIdx.x;
}
__syncthreads();
for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x)
{
if(temp_storage.smem_qidx[i] != -1)
atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]);
}
}
}
__launch_bounds__(TH, 4)
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n)
{
const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK);
int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK;
const int base_idx = (blockIdx.x * NUM_BLOCK);
float vals[NUM];
unsigned char qvals[NUM];
//const int lane_id = threadIdx.x % 2;
typedef hipcub::BlockLoad<float, TH, NUM, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockStore<unsigned char, TH, NUM, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
__shared__ typename LoadFloat::TempStorage loadf;
__shared__ typename StoreChar::TempStorage storec;
__shared__ float smem_code[256];
//__shared__ float smem_code[2][257];
if(threadIdx.x < 256)
{
smem_code[threadIdx.x] = code[threadIdx.x];
//smem_code[0][threadIdx.x] = code[threadIdx.x];
//smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x];
}
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK)
{
// number of values already processed in blocks +
// number of values already processed in this block +
// rand_offset % mod value
valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i;
__syncthreads();
LoadFloat(loadf).Load(&(A[i]), vals, valid_items);
#pragma unroll 4
for(int j = 0; j < NUM; j++)
qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]);
__syncthreads();
StoreChar(storec).Store(&(out[i]), qvals, valid_items);
}
}
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
T vals[NUM_PER_TH];
float rand_vals[NUM_PER_TH];
unsigned char qvals[NUM_PER_TH];
//float local_abs_max = -FLT_MAX;
float local_abs_max = 0.0f;
int local_rand_idx = 0;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
__shared__ typename LoadT::TempStorage loadt;
__shared__ typename LoadFloat::TempStorage loadf;
__shared__ typename StoreChar::TempStorage storec;
__shared__ typename BlockReduce::TempStorage reduce;
__shared__ float smem_code[256];
__shared__ float smem_absmax_value[1];
for(int i = threadIdx.x; i < 256; i+=blockDim.x)
smem_code[i] = code[i];
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
local_abs_max = -FLT_MAX;
__syncthreads();
LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);
// 1. compute local max
// 2. broadcast local max
// 3. normalize inputs and quantize
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, hipcub::Max(), valid_items);
if(threadIdx.x == 0)
smem_absmax_value[0] = local_abs_max;
__syncthreads();
if(threadIdx.x == 0)
absmax[i/BLOCK_SIZE] = local_abs_max;
else
local_abs_max = smem_absmax_value[0];
// __syncwarp();
__syncthreads();
local_abs_max = 1.0f/local_abs_max;
if(STOCHASTIC)
{
local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4);
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
}
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
if(!STOCHASTIC)
qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
else
qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
}
__syncthreads();
StoreChar(storec).Store(&(out[i]), qvals, valid_items);
}
}
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH>
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n)
{
const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
T vals[NUM_PER_TH];
unsigned char qvals[NUM_PER_TH];
float local_abs_max = -FLT_MAX;
typedef hipcub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockStore<T, THREADS, NUM_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ typename LoadChar::TempStorage loadchar;
__shared__ typename StoreT::TempStorage storet;
//__shared__ float smem_code[256];
//float local_code[16];
//if(threadIdx.x < 256)
//smem_code[threadIdx.x] = code[threadIdx.x];
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
local_abs_max = absmax[i/BLOCK_SIZE];
__syncthreads();
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128);
// load code through read-only cache via __ldg
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
__syncthreads();
StoreT(storet).Store(&(out[i]), vals, valid_items);
}
}
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n)
{
const unsigned int numThreads = blockDim.x * gridDim.x;
const int idx = (blockIdx.x * blockDim.x) + threadIdx.x;
__shared__ float smem_code[256];
if(threadIdx.x < 256)
{
smem_code[threadIdx.x] = code[threadIdx.x];
}
__syncthreads();
for (int i = idx;i < n; i += numThreads)
{
out[i] = smem_code[A[i]];
}
}
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n)
{
const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS);
int valid_items = 0;
T g_vals[NUM_VALS];
float s1_vals[NUM_VALS];
float s2_vals[NUM_VALS];
const float correction1 = 1.0f/(1.0f - powf(beta1, step));
const float correction2 = 1.0f/(1.0f - powf(beta2, step));
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
__shared__ union {
typename Load::TempStorage load;
typename LoadFloat::TempStorage loadf;
typename BlockReduce::TempStorage reduce;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;
__syncthreads();
Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
g_vals[j] = gnorm_scale*((float)g_vals[j]);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
{
switch(OPTIMIZER)
{
case ADAM:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
s1_vals[j] *= correction1;
s2_vals[j] *= correction2;
s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update
s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update)
break;
}
}
# pragma unroll NUM_VALS-1
for(unsigned int j = 1; j < NUM_VALS; j++)
s1_vals[0] += s1_vals[j];
__syncthreads();
s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]);
if(threadIdx.x == 0)
atomicAdd(&unorm[0], s1_vals[0]);
// __syncwarp();
__syncthreads();
}
}
#define NUM_PER_THREAD 4
template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{
const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = 0;
float update_scale = 0.0f;
T g_vals[NUM_PER_THREAD];
T p_vals[NUM_PER_THREAD];
float s1_vals[NUM_PER_THREAD];
float s2_vals[NUM_PER_THREAD];
const float correction1 = 1.0f - powf(beta1, step);
const float correction2 = sqrtf(1.0f - powf(beta2, step));
const float step_size = -lr*correction2/correction1;
if(max_unorm > 0.0f)
{
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
else{ update_scale = 1.0f; }
}
else{ update_scale = 1.0f; }
typedef hipcub::BlockLoad<T, TH, NUM_PER_THREAD, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockStore<T, TH, NUM_PER_THREAD, hipcub::BLOCK_STORE_WARP_TRANSPOSE> Store;
typedef hipcub::BlockLoad<float, TH, NUM_PER_THREAD, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockStore<float, TH, NUM_PER_THREAD, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
__shared__ union {
typename Load::TempStorage load;
typename Store::TempStorage store;
typename LoadFloat::TempStorage loadf;
typename StoreFloat::TempStorage storef;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
__syncthreads();
Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items);
__syncthreads();
Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
g_vals[j] = gnorm_scale*((float)g_vals[j]);
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
switch(OPTIMIZER)
{
case ADAM:
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
if(weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
}
break;
}
}
__syncthreads();
Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
__syncthreads();
StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
__syncthreads();
StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items);
}
}
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n)
{
const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS);
int valid_items = 0;
T g_vals[NUM_VALS];
float s1_vals[NUM_VALS];
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
__shared__ union {
typename Load::TempStorage load;
typename LoadFloat::TempStorage loadf;
typename BlockReduce::TempStorage reduce;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;
__syncthreads();
Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
g_vals[j] = gnorm_scale*((float)g_vals[j]);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
{
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = (float)g_vals[j]; // state update
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
}
}
# pragma unroll
for(unsigned int j = 1; j < NUM_VALS; j++)
s1_vals[0] += s1_vals[j];
__syncthreads();
s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items);
if(threadIdx.x == 0)
atomicAdd(&unorm[0], s1_vals[0]);
// __syncwarp();
__syncthreads();
}
}
template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit1State(T *g, T *p,
float *state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{
const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = 0;
float update_scale = 0.0f;
if(max_unorm > 0.0f)
{
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; }
else{ update_scale = 1.0f; }
}
else{ update_scale = 1.0f; }
T g_vals[NUM_PER_THREAD];
T p_vals[NUM_PER_THREAD];
float s1_vals[NUM_PER_THREAD];
typedef hipcub::BlockLoad<T, TH, NUM_PER_THREAD, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockStore<T, TH, NUM_PER_THREAD, hipcub::BLOCK_STORE_WARP_TRANSPOSE> Store;
typedef hipcub::BlockLoad<float, TH, NUM_PER_THREAD, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockStore<float, TH, NUM_PER_THREAD, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
__shared__ union {
typename Load::TempStorage load;
typename Store::TempStorage store;
typename LoadFloat::TempStorage loadf;
typename StoreFloat::TempStorage storef;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
__syncthreads();
Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
__syncthreads();
Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
g_vals[j] = gnorm_scale*((float)g_vals[j]);
if(weight_decay > 0.0f)
g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay);
}
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = (float)g_vals[j];
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
break;
}
}
}
__syncthreads();
Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
__syncthreads();
StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
}
}
#define NUM8BIT 16
#define NUM_THREADS 256
#define NUM_PER_BLOCK 4096
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2,
float *unorm,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
const float gnorm_scale, const int n)
{
const int n_full = gridDim.x * NUM_PER_BLOCK;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK);
float g_val = 0.0f;
float local_max_s1 = -FLT_MAX;
float local_max_s2 = -FLT_MAX;
float local_unorm = 0.0f;
float s2_vals[NUM8BIT];
float s1_vals[NUM8BIT];
T g_vals[NUM8BIT];
unsigned char m_c1[NUM8BIT];
unsigned char r_c2[NUM8BIT];
typedef hipcub::BlockLoad<T, NUM_THREADS, NUM8BIT, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef hipcub::BlockReduce<float, NUM_THREADS> BlockReduce;
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadUInt8::TempStorage loadc;
typename BlockReduce::TempStorage reduce;
} temp_storage;
__shared__ float smem_quantiles1[256];
__shared__ float smem_quantiles2[256];
if(threadIdx.x < 256)
{
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x];
}
__syncthreads();
for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128);
__syncthreads();
LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128);
__syncthreads();
#pragma unroll 16
for(int j = 0; j < NUM8BIT; j++)
{
g_val = g_vals[j];
g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1;
s1_vals[j] += (1.0f-beta1)*g_val;
local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j]));
}
#pragma unroll 16
for(int j = 0; j < NUM8BIT; j++)
{
g_val = g_vals[j];
g_val *= gnorm_scale;
s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2;
s2_vals[j] += (1.0f-beta2)*g_val*g_val;
local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j]));
}
if(unorm != NULL)
{
#pragma unroll 16
for(int j = 0; j < NUM8BIT; j++)
{
float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step));
float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step));
s1_vals[j] *= correction1;
s2_vals[j] *= correction2;
float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update
local_unorm += update_val*update_val;
}
}
}
__syncthreads();
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items);
__syncthreads();
local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, hipcub::Max(), valid_items);
if(unorm != NULL)
{
__syncthreads();
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items);
}
if(threadIdx.x == 0)
{
atomicMax(&new_max1[0], local_max_s1);
atomicMax(&new_max2[0], local_max_s2);
if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); }
}
}
#define NUM_PER_THREAD2 4
#define NUM_THREADS2 1024
#define NUM_PER_BLOCK2 4096
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS2, 1)
kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay,
const float gnorm_scale, const int n)
{
const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
int valid_items = 0;
float g_val = 0.0f;
float s1_vals[NUM_PER_THREAD2];
float s2_vals[NUM_PER_THREAD2];
const float correction1 = 1.0f - powf(beta1, step);
const float correction2 = sqrtf(1.0f - powf(beta2, step));
const float step_size = -lr*correction2/correction1;
//const float step_size = -lr*correction2/correction1;
float new_max_val1 = 1.0f/new_max1[0];
float new_max_val2 = 1.0f/new_max2[0];
float update_scale = 1.0f;
if(max_unorm > 0.0f)
{
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
else{ update_scale = 1.0f; }
}
else{ update_scale = 1.0f; }
unsigned char c1s[NUM_PER_THREAD2];
unsigned char c2s[NUM_PER_THREAD2];
T p_vals[NUM_PER_THREAD2];
T g_vals[NUM_PER_THREAD2];
typedef hipcub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[256];
__shared__ float smem_quantiles2[256];
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadChar::TempStorage loadc;
typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh;
} temp_storage;
if(threadIdx.x < 512)
{
if(threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
else
smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256];
}
__syncthreads();
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[c1s[j]];
s1_vals[j] = s1_vals[j]*max1[0];
s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1);
// make sure state1 term has still the same sign after quantization
// (not needed for state2 term which has only positive values)
if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j]))
{
if(s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
}
s2_vals[j] = smem_quantiles2[c2s[j]];
s2_vals[j] = s2_vals[j]*max2[0];
s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2);
}
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
{
p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps))))));
if(weight_decay > 0.0f)
p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay));
}
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items);
__syncthreads();
}
}
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
const float beta1,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
const float weight_decay,
const float gnorm_scale, const int n)
{
const int n_full = gridDim.x * NUM_PER_BLOCK;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK);
float g_val = 0.0f;
float local_max_s1 = -FLT_MAX;
float local_unorm = 0.0f;
float s1_vals[NUM8BIT];
T g_vals[NUM8BIT];
unsigned char m_c1[NUM8BIT];
typedef hipcub::BlockLoad<T, NUM_THREADS, NUM8BIT, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef hipcub::BlockReduce<float, NUM_THREADS> BlockReduce;
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadUInt8::TempStorage loadc;
typename BlockReduce::TempStorage reduce;
} temp_storage;
__shared__ float smem_quantiles1[256];
if(threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
__syncthreads();
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
__syncthreads();
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128);
#pragma unroll 16
for(int j = 0; j < NUM8BIT; j++)
{
g_val = g_vals[j];
g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0];
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = (float)g_vals[j];
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
if(unorm != NULL)
local_unorm += s1_vals[j]*s1_vals[j];
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
}
local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j]));
}
}
__syncthreads();
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items);
if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); }
if(unorm != NULL)
{
__syncthreads();
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items);
if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); }
}
}
template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
float weight_decay,
const float gnorm_scale, const int n)
{
const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
int valid_items = 0;
float g_val = 0.0f;
float s1_vals[NUM_PER_THREAD2];
float new_max_val1 = 1.0f/new_max1[0];
float update_scale = 1.0f;
if(max_unorm > 0.0f)
{
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
else{ update_scale = 1.0f; }
}
else{ update_scale = 1.0f; }
unsigned char c1s[NUM_PER_THREAD2];
T p_vals[NUM_PER_THREAD2];
T g_vals[NUM_PER_THREAD2];
typedef hipcub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[256];
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadChar::TempStorage loadc;
typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh;
} temp_storage;
if(threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
__syncthreads();
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
if(weight_decay > 0.0f)
g_val += ((float)p_vals[j])*weight_decay;
s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = g_vals[j];
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
break;
}
c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1);
// make sure state1 term has still the same sign after quantization
if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j]))
{
if(s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
}
}
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
__syncthreads();
}
}
template<typename T, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n)
{
const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
int valid_items = 0;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
__shared__ typename BlockReduce::TempStorage reduce;
__shared__ typename LoadT::TempStorage loadT;
T vals[NUM_VALS];
float local_sum = 0.0f;
for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE)
{
valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
local_sum = 0.0f;
__syncthreads();
LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f);
#pragma unroll NUM_VALS
for(int j = 0; j < NUM_VALS; j++)
local_sum += ((float)vals[j])*((float)vals[j]);
local_sum = BlockReduce(reduce).Sum(local_sum, valid_items);
if(threadIdx.x == 0)
{
if(step == 1)
{
// initialize with the same norm for all positions
//#pragma unroll 10
for(int j = 0; j < 100; j++)
atomicAdd(&gnorm_vec[j], local_sum);
}
else
atomicAdd(&gnorm_vec[step % 100], local_sum);
}
}
}
#define LANES 2
#define QUAD 3
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3)
__global__ void
kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* absmax1, float* absmax2,
float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n)
{
//const int n_full = n + (n%BLOCK_SIZE);
const int n_full = gridDim.x * BLOCK_SIZE;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
int valid_items = 0;
float g_val = 0.0f;
float s1_vals[N_PER_TH];
float s2_vals[N_PER_TH];
// 2-5%
const float correction1 = 1.0f - __powf(beta1, step);
const float correction2 = sqrtf(1.0f -__powf(beta2, step));
const float step_size = __fdividef(-lr*correction2,correction1);
const int lane_id = threadIdx.x % LANES;
float new_local_abs_max1 = -FLT_MAX;
float new_local_abs_max2 = -FLT_MAX;
float quadrants1[QUAD];
float quadrants2[QUAD];
unsigned char c1s[N_PER_TH];
unsigned char c2s[N_PER_TH];
T g_vals[N_PER_TH];
typedef hipcub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[LANES][257];
__shared__ float smem_quantiles2[LANES][257];
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce2;
__shared__ typename BlockReduce1::TempStorage reduce1;
__shared__ typename BlockReduce2::TempStorage reduce2;
__shared__ float smem_exchange1[1];
__shared__ float smem_exchange2[1];
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadChar::TempStorage loadc;
typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh;
} temp_storage;
// init: 0.2 -> 0.23
// 0.23 -> 0.23
smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x];
# pragma unroll
for(unsigned int j = 1; j < LANES; j++)
{
smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];
smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x];
}
__syncthreads();
#pragma unroll
for(int k = 0; k < QUAD; k++)
{
quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
}
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
// loads: 0.23 -> 0.85/1.44
valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
__syncthreads();
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
new_local_abs_max1 = -FLT_MAX;
new_local_abs_max2 = -FLT_MAX;
// update: 2.48/1.57 -> 2.51/1.60
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
}
new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j]));
}
// reduce: 2.51/1.60 -> 2.67/1.69
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max());
new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, hipcub::Max());
if(threadIdx.x == 0)
{
smem_exchange1[0] = new_local_abs_max1;
smem_exchange2[0] = new_local_abs_max2;
}
__syncthreads();
if(threadIdx.x == 0)
{
absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
absmax2[i/BLOCK_SIZE] = new_local_abs_max2;
}
else
{
new_local_abs_max1 = smem_exchange1[0];
new_local_abs_max2 = smem_exchange2[0];
}
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), g_vals, valid_items, (T)0.0f);
// reduce: 2.67/1.69 -> 2.67/1.70
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
if(weight_decay > 0.0f)
g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay));
}
}
// store: 0.85/1.44 -> 2.48/1.57
__syncthreads();
StoreT(temp_storage.storeh).Store(&(p[i]), g_vals, valid_items);
// quantizaztion: 2.67/1.70 -> 3.4/3.3
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2));
// make sure state1 term has still the same sign after quantization
// (not needed for state2 term which has only positive values)
if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j]))
{
if(s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
}
}
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items);
}
}
#define LANES 2
#define QUAD 3
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3)
__global__ void
kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* absmax1,
float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n)
{
//const int n_full = n + (n%BLOCK_SIZE);
const int n_full = gridDim.x * BLOCK_SIZE;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
int valid_items = 0;
float g_val = 0.0f;
float s1_vals[N_PER_TH];
// 2-5%
const int lane_id = threadIdx.x % LANES;
float new_local_abs_max1 = -FLT_MAX;
float quadrants1[QUAD];
unsigned char c1s[N_PER_TH];
T g_vals[N_PER_TH];
T p_vals[N_PER_TH];
typedef hipcub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[LANES][257];
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
__shared__ typename BlockReduce1::TempStorage reduce1;
__shared__ float smem_exchange1[1];
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadChar::TempStorage loadc;
typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh;
} temp_storage;
// init: 0.2 -> 0.23
// 0.23 -> 0.23
smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
# pragma unroll
for(unsigned int j = 1; j < LANES; j++)
smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];
__syncthreads();
#pragma unroll
for(int k = 0; k < QUAD; k++)
quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
// loads: 0.23 -> 0.85/1.44
valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
__syncthreads();
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
new_local_abs_max1 = -FLT_MAX;
// update: 2.48/1.57 -> 2.51/1.60
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
if(weight_decay > 0.0f)
g_val += ((float)p_vals[j])*weight_decay;
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = g_val;
else
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + (g_val*g_val);
break;
}
}
new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
}
// reduce: 2.51/1.60 -> 2.67/1.69
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max());
if(threadIdx.x == 0)
smem_exchange1[0] = new_local_abs_max1;
__syncthreads();
if(threadIdx.x == 0)
absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
else
new_local_abs_max1 = smem_exchange1[0];
// reduce: 2.67/1.69 -> 2.67/1.70
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
switch(OPTIMIZER)
{
case MOMENTUM:
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
break;
case RMSPROP:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
break;
case ADAGRAD:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
break;
}
}
}
// store: 0.85/1.44 -> 2.48/1.57
__syncthreads();
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
// quantizaztion: 2.67/1.70 -> 3.4/3.3
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
// make sure state1 term has still the same sign after quantization
// (not needed for state2 term which has only positive values)
if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j]))
{
if(s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
}
}
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
}
}
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols)
{
// 0. reset stats to -FLT_MAX
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
// 2. compute col max (per thread); store in smem due to register pressure
// 3. compute row max (per block); store in smem to accumulate full global mem transation
// 4. store data via atomicMax
// each block loads TILE_COLs columns and TILE_ROW rows
// after reading a tile the row counter increase by TILE_ROWS
// the col counter reset after reading TILE_COL elements
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
const int base_idx = (base_row*cols) + base_col;
const int items_per_load = ITEMS_PER_THREAD*THREADS;
typedef hipcub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, hipcub::BLOCK_LOAD_VECTORIZE> LoadT;
typedef hipcub::BlockReduce<float, THREADS> BlockRowReduce;
typedef hipcub::BlockReduce<int, THREADS> BlockRowSum;
typedef hipcub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;
__shared__ union {
typename BlockExchange::TempStorage exchange;
typename BlockRowReduce::TempStorage rowreduce;
typename BlockRowSum::TempStorage rowsum;
typename LoadT::TempStorage loadt;
} temp_storage;
__shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
__shared__ int smem_row_nnz_values[TILE_ROWS];
half local_data[ITEMS_PER_THREAD];
float local_data_fp32[ITEMS_PER_THREAD];
float local_col_absmax_values[ITEMS_PER_THREAD];
int local_row_nnz_count = 0;
float row_absmax = -FLT_MAX;
// 0. reset stats to -FLT_MAX
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0;
}
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_col_absmax_values[j] = -FLT_MAX;
__syncthreads();
int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
int i = base_idx;
// we load row after row from the base_position
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
for(int row = 0; row < TILE_ROWS; row++)
{
if(base_row+row >= rows){ break; }
local_row_nnz_count = 0;
i = base_idx + ((row)*cols);
// each thread gets data from the same column
__syncthreads();
LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f));
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data[j] = fabsf(local_data[j]);
if(SPARSE_DECOMP)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
if((float)local_data[j] >= nnz_threshold)
{
local_row_nnz_count += 1;
local_data[j] = 0.0f;
}
}
// 2. compute col max (per thread); store in smem due to register pressure
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
// take the col max for this row
// we use shared memory because register pressure is too high if we do this locally
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j]));
local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
// 3. compute row max (per block); store in smem to accumulate full global mem transation
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data_fp32[j] = local_data[j];
__syncthreads();
row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, hipcub::Max());
if(SPARSE_DECOMP)
{
__syncthreads();
local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count);
}
// we store the data temporarily in shared memory so we
// can execute a full atomic block transaction into global memory later
// we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores
if(threadIdx.x == 0)
{
smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax;
// each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block
smem_row_nnz_values[row] = local_row_nnz_count;
}
__syncthreads();
}
// 4. store data via atomicMax
// to store col data efficienctly we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0
// into a striped arangement: [0, 8, 16, 24, ..] for t0
__syncthreads();
BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values);
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
if(base_col+threadIdx.x+(j*THREADS) < cols)
{
float val = colStats[base_col+(threadIdx.x+(j*THREADS))];
if(val < local_col_absmax_values[j])
atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]);
}
for(int j = 0; j < ITEMS_PER_THREAD; j++)
if(base_row+threadIdx.x+(j*THREADS) < rows)
{
float val = rowStats[base_row+(threadIdx.x+(j*THREADS))];
if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)])
atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]);
}
if(SPARSE_DECOMP)
if(threadIdx.x < TILE_ROWS)
nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x];
}
template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 0>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n)
{
// Strategy: To dequantize we need to load col/row statistics. This can be very expensive
// since different row/col stats need to be loaded with each thread.
// (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure
// and would lead to low global load utilization.
// (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads
// for each thread and this is duplicated by a factor of 32/num-cols-per-thread.
// (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock.
// This allows for efficient row/col loading from shared memory within the tile.
// We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has
// the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts
// we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the
// shared memory loads.
// data is in 32 column-tile major with tile width 32 columns and numRows rows
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
// L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
// C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register))
// C2. Compute normalization values and store col values in register
// S1. Store C1 into 16-bit output
// S2. Store col/row statistics of new buffer in shared memory
// We allow for sub-tiles to span multiple col32 tiles. This is okay
// since the items per thread only rely on a single column statistic.
const int n_out = numRows*numCols;
int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
// we have tiles of size numRows*32, thus col only increases every numRows
// num_row_tiles is the tiles after which the column increases by 32
// blockIdx.x is the index of the current tile
int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
// base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
// SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS
// subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
// Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads.
// For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have
// 1024*1024/(128*32) = 256 tiles
// 256 tiles are 256*128*32/4 = 256*1024 threads
// 1. Figure out how index relates to the start of the sub-tile
// 2. Each thread < SUBTILE_ROWS calculates row index
// 3. Load striped and store in shared memory
int local_values[ITEMS_PER_THREAD];
half local_output[ITEMS_PER_THREAD];
float local_rowStats[ITEMS_PER_THREAD];
__shared__ float smem_rowStats[SUBTILE_ROWS];
typedef hipcub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, hipcub::BLOCK_LOAD_DIRECT> LoadInt32;
typedef hipcub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
__shared__ typename LoadInt32::TempStorage loadint32;
__shared__ typename ExchangeInt32::TempStorage exchangeint32;
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
float colStat = col >= numCols ? 0.0f : colStats[col];
float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]);
// no block loads for rows for now -- keep it simple
for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
{
// todo: is this global mem access slow due to overlaps or does the L1 cache work well here?
int row = (base_row+j) % numRows; // wrap around
// each warp accesses the same element, for four consequitive elements
// todo: update description about striped shared memory, it is not needed
// rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements
smem_rowStats[j] = rowStats[row];
}
__syncthreads();
// each block processes SUBTILE_ROWS*32 elements
const int items_per_load = THREADS*ITEMS_PER_THREAD;
const int rows_per_load = items_per_load/32;
int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile
int row_offset = 0;
// subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed
int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32);
for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load)
{
int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset);
int valid_items = valid_rows*32;
if(valid_items <= 0) // the sub-tile might have more elements than the tile itself
break;
// L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0);
ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values);
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue);
//absmax_col = fmax(fabsf(local_output[j]), absmax_col);
// we store data in row major
// to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3]
// so that each thread holds ITEMS_PER_THREAD consecutive items for each row
// this way throughput into storage is increased by a factor of ~2x
// for now we use a simple store
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols);
if(outIdx< n_out && col < numCols)
out[outIdx] = local_output[j];
}
row_offset += rows_per_load;
}
}
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols)
{
// assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD
// Each thread reads the same column but multiple rows
// Rows are loaded in shared memory and access is shared across the threadblock (broadcast)
// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
// 2. quantize data with row/col stats
// 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance)
// each block loads TILE_COLs columns and TILE_ROW rows
// after reading a tile the row counter increase by TILE_ROWS
// the col counter reset after reading TILE_COL elements
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
const int base_idx = (base_row*cols) + base_col;
const int items_per_load = ITEMS_PER_THREAD*THREADS;
typedef hipcub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, hipcub::BLOCK_LOAD_VECTORIZE> LoadHalf;
__shared__ typename LoadHalf::TempStorage loadhalf;
typedef hipcub::BlockStore<char, THREADS, ITEMS_PER_THREAD, hipcub::BLOCK_STORE_VECTORIZE> StoreInt8;
__shared__ typename StoreInt8::TempStorage storeint8;
__shared__ float smem_row_stats[TILE_ROWS];
__shared__ unsigned int smem_nnz_row_idx[TILE_ROWS];
half local_data[ITEMS_PER_THREAD];
float local_col_stats[ITEMS_PER_THREAD];
char local_quantized_data[ITEMS_PER_THREAD];
// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols)
local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]);
for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x)
{
if(base_row + i < rows)
smem_row_stats[i] = rowStats[base_row+i];
if(SPARSE_DECOMP)
smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i];
}
__syncthreads();
// we load row after row from the base_position
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
for(int row = 0; row < TILE_ROWS; row++)
{
if(base_row + row >= rows){ break; }
int i = base_idx + (row*cols);
int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f);
float row_stat = __fdividef(127.0f, smem_row_stats[row]);
// 2. quantize data with row/col stats
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
// we already pre-normalized the col/row stat:
// what this does is float/absmax*127 = int8
if(SPARSE_DECOMP)
{
if(fabsf((float)local_data[j]) >= threshold)
{
local_quantized_data[j] = 0;
int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX);
rowidx[old_idx] = base_row+row;
colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j;
val[old_idx] = local_data[j];
}
else
{
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
}
}
else
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
}
StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items);
// 2. quantize data with row/col stats
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
// we already pre-normalized the col/row stat:
// what this does is float/absmax*127 = int8
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j]));
}
__syncthreads();
StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items);
}
}
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
{
// 0. Load data into 32*32 shared memory tiles
// 1. transpose / reorder in shared memory
// 2. store
// COL32 FORMAT:
// rows*32 tiles
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
// AMPERE FORMAT:
// 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
// row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
// To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values
// As such we need:
// at least 32*4 shared memory tiles for col32; preferably 32*32
// at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
// at least 32*8 shared memory tiles for col4_turing: preferably 32*32
// for efficient loading of row major we need to load 128 elements and repeat this 32 items
// this would imply a 32x128 shared memory tile -> 4kb
// It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb
// we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy
// for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough
// register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM
//
// to make the shared memory work with that occupancy we might need to union the block loads/stores
// each block loads TILE_COLs columns and TILE_ROW rows
// after reading a tile the row counter increase by TILE_ROWS
// the col counter reset after reading TILE_COL elements
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
const int base_idx = (base_row*cols) + base_col;
// we load 128 bytes per warp with
// 32 rows for transposes that fill col32 types
// so that we can have contiguous stores
__shared__ char smem_data[32*33*ITEMS_PER_THREAD];
char local_data[ITEMS_PER_THREAD];
typedef hipcub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
// we load row after row from the base_position
// Load data row by row
int warps = blockDim.x/32;
int warp_id = threadIdx.x/32;
int warp_lane = threadIdx.x % 32;
int offset = 0;
int smem_row = 0;
// each warp loads one row of 128 bytes
for(int row = warp_id; row < TILE_ROWS; row+=warps)
{
int i = base_idx + (row*cols);
// we load up to 128 bytes/items per load
int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col;
// 0. Load data into 32*32 shared memory tiles
if(base_row + row < rows)
{
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
int col_idx = warp_lane+(j*32);
if(col_idx < valid_items)
local_data[j] = A[i+col_idx];
else
local_data[j] = 0;
}
}
else
{
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data[j] = 0;
}
if(TRANSPOSE)
{
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
int local_col = (32*j)+warp_lane;
//int local_row = row;
// store as 256x32
smem_data[(local_col*33) + row] = local_data[j];
}
}
else
{
// treat smem as 32x256, that is 32 rows and 256 columns
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j];
}
smem_row += warps;
// 1. transpose / reorder in shared memory
if(smem_row % 32 == 0)
{
smem_row = 0;
__syncthreads();
for(int subrow = warp_id; subrow < 32; subrow+=warps)
{
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
switch(FORMAT)
{
case COL32:
if(TRANSPOSE)
{
// data lies in shared memory in the following way:
// row0 [col0 col1 ... col31]
// row1 [col0 col1 ... col31]
// ...
//
// As such we read consequtive entries with 256 threads (8rows x 32 columns)
// as j increase, the row increase by a factor of 8
// We load 8 rows per subrow loop, and subrow increase by 8 per loop
// so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
{
// each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
// each 32 columns we have new tile
// each tile has size outRows*32 and base_row is done in increments of 32
offset = base_row*outRows;
out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
}
}
else
{
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
{
offset = (base_col/32)*(32*rows);
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data;
}
}
break;
case COL_TURING:
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
//
// [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
if(TRANSPOSE)
{
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
{
// each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
// each 32 columns we have new tile
// each tile has size 8*32 = 256 elements offset
// for each row offset of 8 we increaes the tile first
// after all rows are exhausted, we increase the col
int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
// we increase by row_tile_column every 32 columns
// base_row increase in increments of 32
//int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
//int col_offset = (base_row/32)*row_tile_column;
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
// 256*outRows/8*base_row/32 = outRows*base_row
int col_offset = outRows*base_row;
offset = row_offset+col_offset;
// since we process even number of rows with each j (8) and with each subrow (8j) we can determine
// odd or even rows with the warp_id (each warp processes one row)
// the col is warp_lane (max 32 columns per row) and the row warp_id
if(warp_id % 2 == 1)
// odd
offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2);
else
// even
offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2);
out[offset] = data;
}
}
else
{
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
{
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
// set offset designates the tile offset among the 8*32 tiles
// we first increase rows and then columns. Since we load 128 columns at once
// we increase the offset by outRows*32 every 32 columns
// additionally, we increase the offset by 8*32=256 every 8 rows
offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile)
// first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd
// each of these has 32 values in total for 32*4 = 128 as offset if odd
// every set of 4 columns increases the total offset by 16
// each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2
// this happends every 8 rows anew (subrow % 8)
// one writes 4 columns at once that is (col % 4) for the particular index in the subtile
int subcol = warp_lane;
// add local offset (4x4 sub-tile)
if(subrow % 2 == 1)
// odd
offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2);
else
// even
offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2);
out[offset] = data;
}
}
break;
case COL_AMPERE:
// AMPERE FORMAT:
// 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
// row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
if(TRANSPOSE)
{
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
{
// each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
// each 32 columns we have new tile
// each tile has size 32*32 = 1024 elements offset
// for each row offset of 32 we increaes the tile first
// after all rows are exhausted, we increase the col
int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
// we increase by row_tile_column every 32 columns
// base_row increase in increments of 32
//int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
//int col_offset = (base_row/32)*row_tile_column;
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
// 1024*outRows/32*base_row/32 = outRows*base_row
int col_offset = outRows*base_row;
offset = row_offset+col_offset;
// same as in the non-transpose case (see below)
// the difference is that now rows = cols
// in this case warp_id = subrow
// [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
// subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
// every 2 rows, the offset increases by two [0, 1, 8, 9...]
// every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset
int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2);
// global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane
out[offset + (ampere_row*32) + warp_lane] = data;
}
}
else
{
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
{
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
// set offset designates the tile offset among the 32*32 tiles
// we first increase rows and then columns. Since we load 128 columns at once
// we increase the offset by outRows*32 every 32 columns
// additionally, we increase the offset by 32*32=1024 every 32 rows
offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile)
// [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
// subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
// every 2 rows, the offset increases by two [0, 1, 8, 9...]
// every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2);
// global offset + row with 32 cols each + 32 cols per j + col_idx
out[offset + (local_row*32) + warp_lane] = data;
}
}
break;
}
}
}
}
}
}
#define C 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
// 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block
// If a block finishes, the next one is scheduled. Since the last blocks like have fewer
// elements they finish faster "fillin up" the gaps left by larger blocks
// without tensor cores
// 1. use rowidx_length to find what to load (as many blocks as there are rows)
// 2. Load A into registers
// 3. each warp loads all required rows of B but each warp is offset by k
// 4. Do mma operations that accumulate into registers
// 5. Each warp stores its output row into matrix C
const int count = max_count[blockIdx.x];
const int local_max_idx = max_idx[blockIdx.x];
const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
const int local_row_idx = rowidx[offset];
const int warp_id = threadIdx.x / 32;
const int warp_idx = threadIdx.x % 32;
const int warp_offset = (warp_id*32)*SPMM_ITEMS;
const int num_items = BITS == 8 ? 8 : 8;
int idx_col_B = warp_offset;
int local_idx_col_B_offset = 0;
half local_valA[MAX_SPARSE_COUNT];
int local_colidxA[MAX_SPARSE_COUNT];
half local_valC[SPMM_ITEMS];
T local_valsB[num_items];
half local_valOut[num_items];
// 128 byte loads per warp == 4 bytes per thread
// 2. Load A into registers
for(int j = 0; j < MAX_SPARSE_COUNT; j++)
{
local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f);
local_colidxA[j] = j < count ? colidx[offset+j] : 0;
}
// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192
// we expect each warp to be SPMM_ITEMS*32 apart
// we have a total of 128 bytes for the bank with a bank size of 4 bytes
// added 3 bytes = 6 values between warps should reduce bank conflicts
__shared__ half smem_dequant_stats[SMEM_SIZE];
while(idx_col_B < colsB)
{
if(dequant_stats != NULL)
{
for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
if((idx_col_B+i-local_idx_col_B_offset) < colsB)
smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset];
__syncthreads();
}
#pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j++)
local_valC[j] = 0.0f;
#pragma unroll
for(int i = 0; i < count; i++)
{
// 3. each warp loads all required rows of B but each warp is offset by k
int row_offset = colsB*local_colidxA[i];
#pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j+=num_items)
{
// 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached
int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j;
if(idx >= colsB){ break; }
//printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx);
if((idx+num_items < colsB))
{
if(BITS == 8)
reinterpret_cast<float2(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float2*>(B)[(row_offset+ idx)/num_items];
else
reinterpret_cast<float4(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float4*>(B)[(row_offset+ idx)/num_items];
}
else
{
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
if(idx+k < colsB)
local_valsB[k] = B[row_offset+idx+k];
else
local_valsB[k] = 0.0f;
}
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
{
//if((float)local_valsB[k] != 0.0)
// printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB);
if(BITS == 8 && dequant_stats != NULL)
// we do texture cache reads (__ldg) on dequant_stats which should be super fast
{
float valB = local_valsB[k];
float valA = local_valA[i];
if(valB != 0.0 && valA != 0.0)
local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA;
}
else
local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i];
}
}
}
int idx_row_C = (colsB*local_row_idx);
#pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j+=num_items)
{
//int idx_col_C = idx_col_B + (32*j) + warp_idx;
int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j;
int idx_val = idx_col_C + idx_row_C;
if(idx_col_C +num_items < colsB)
{
// load outputs to do inplace addition
reinterpret_cast<float4(&)[num_items/4]>(local_valOut)[0] = reinterpret_cast<float4*>(out)[idx_val/num_items];
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k];
reinterpret_cast<float4*>(out)[idx_val/num_items] = reinterpret_cast<float4(&)[num_items]>(local_valC)[j/num_items];
}
else
{
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
if(idx_col_C + k < colsB)
out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k];
}
}
idx_col_B += blockDim.x*SPMM_ITEMS;
local_idx_col_B_offset += blockDim.x*SPMM_ITEMS;
}
}
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
{
int local_colidx = idx[blockIdx.x];
if(FORMAT==COL_TURING)
{
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
// columns are grouped in increments of 4, meaning that one has the following rows and columns
// rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
// cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...]
// each thread reads 1 element = 1 row
for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
{
int offset_per_col_tile = ((rowsA+7)/8)*32*8;
int tile_offset_rows = (row/8)*32*8;
int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
int offset = 0;
int subtile_col_idx = local_colidx%32;
int subtile_row_idx = row % 8;
if(row % 2 == 1)
offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2);
else
// even
offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2);
offset += tile_offset_rows + tile_offset_cols;
char val = A[offset];
int out_idx = (row*idx_size) + blockIdx.x;
out[out_idx] = val;
}
}
else if(FORMAT == COL_AMPERE)
{
for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
{
// we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element
// within each tile.
int offset_per_col_tile = ((rowsA+31)/32)*32*32;
int tile_offset_rows = (row/32)*32*32;
int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
int subtile_col_idx = local_colidx%32;
int subtile_row_idx = row % 32;
// this magic is taken from the cublasLt doc (search for COL32)
int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
offset += tile_offset_cols + tile_offset_rows;
char val = A[offset];
int out_idx = (row*idx_size) + blockIdx.x;
out[out_idx] = val;
}
}
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n);
template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n);
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float *unorm, \
const float beta1, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float)
MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float)
#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float* state2, float *unorm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \
MAKE_PreconditionOptimizer32bit2State(ADAM, half)
MAKE_PreconditionOptimizer32bit2State(ADAM, float)
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
float *unorm, \
const float beta1, \
const float eps, const int step, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
const float weight_decay, \
const float gnorm_scale, \
const int n); \
MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
MAKE_PreconditionStatic8bit1State(RMSPROP, half)
MAKE_PreconditionStatic8bit1State(RMSPROP, float)
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
float weight_decay, \
const float gnorm_scale, \
const int n); \
MAKE_optimizerStatic8bit1State(MOMENTUM, half)
MAKE_optimizerStatic8bit1State(MOMENTUM, float)
MAKE_optimizerStatic8bit1State(RMSPROP, half)
MAKE_optimizerStatic8bit1State(RMSPROP, float)
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
float *unorm, \
const float beta1, const float beta2, \
const float eps, const int step, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
const float gnorm_scale, \
const int n); \
MAKE_PreconditionStatic8bit2State(ADAM, half)
MAKE_PreconditionStatic8bit2State(ADAM, float)
#define MAKE_optimizerStatic8bit2State(oname, gtype) \
template __global__ void kOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float beta2, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
float weight_decay, \
const float gnorm_scale, \
const int n); \
MAKE_optimizerStatic8bit2State(ADAM, half)
MAKE_optimizerStatic8bit2State(ADAM, float)
template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
template __global__ void kQuantizeBlockwise<half, 4096, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 256, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 256, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 128, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 128, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 64, 1, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 64, 1, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 256, 128, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 256, 128, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 128, 64, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 128, 64, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 64, 64, 1>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 64, 64, 1>(float *code, unsigned char * A, float * absmax, float *out, const int n);
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
const float beta1, const float beta2, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
float* absmax1, float* absmax2, \
float weight_decay, \
const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
gtype* p, gtype* __restrict__ const g, unsigned char* state1, \
const float beta1, const float beta2, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, \
float* absmax1, \
float weight_decay, \
const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <hip/hip_runtime.h>
#include <float.h>
#include "ops.cuh"
#ifndef kernels
#define kernels
template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER>
__global__ void
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
const float beta1,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
const float weight_decay,
const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
float weight_decay, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void
kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2,
float *unorm,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit2StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* absmax1,
float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
#endif
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <hip/hip_runtime.h>
#include "ops.cuh"
#include "kernels.cuh"
// #include <hipcub/device/device_scan.cuh>
#include <limits>
// #include <BinSearch.h>
#include <AAlloc.h>
#include <BinAlgo.h>
#include <cassert>
// #include <common.h>
using namespace BinSearch;
using std::cout;
using std::endl;
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
{
int threads = 512;
int num_blocks = n/threads;
num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1;
kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
{
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
CUDA_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float)));
kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void quantize(float *code, float *A, unsigned char *out, int n)
{
int num_blocks = n/1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void dequantize(float *code, unsigned char *A, float *out, int n)
{
int num_blocks = n/1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
{
int num_blocks = n/blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
if(STOCHASTIC == 1)
assert(blocksize == 4096);
if(blocksize == 4096)
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 2048)
kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 1024)
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 512)
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 256)
kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 128)
kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 64)
kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
{
int num_blocks = n/blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
if(blocksize == 4096)
kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
else if(blocksize == 2048)
kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
else if(blocksize == 1024)
kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
else if(blocksize == 512)
kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
else if(blocksize == 256)
kDequantizeBlockwise<T, 256, 128, 2><<<num_blocks, 256/2>>>(code, A, absmax, out, n);
else if(blocksize == 128)
kDequantizeBlockwise<T, 128, 64, 2><<<num_blocks, 128/2>>>(code, A, absmax, out, n);
else if(blocksize == 64)
kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
{
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
switch(OPTIMIZER)
{
case ADAM:
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
}
}
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
unsigned char* state1, unsigned char* state2,
float *unorm, float max_unorm, float param_norm,
float beta1, float beta2,
float eps, int step, float lr,
float* quantiles1, float* quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay,
const float gnorm_scale, int n)
{
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); }
switch(OPTIMIZER)
{
case ADAM:
CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
default:
break;
}
}
#define BLOCKSIZE_2STATE 2048
#define NUM_2STATE 8
#define BLOCKSIZE_1STATE 2048
#define NUM_1STATE 8
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
{
int num_blocks = 0;
switch(OPTIMIZER)
{
case ADAM:
num_blocks = n/BLOCKSIZE_2STATE;
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
num_blocks = n/BLOCKSIZE_1STATE;
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
}
}
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
{
int num_blocks = n/2048;
num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
{
cout << "" << endl;
cout << "=============================================" << endl;
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
cout << "=============================================" << endl;
cout << "" << endl;
assert(false);
return ;
// const int falpha = 1;
// const int fbeta = 0;
// const void * alpha = &falpha;
// const void * beta = &fbeta;
// hipblasStatus_t status;
// status = hipblasGemmEx(context->m_handle,
// transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
// transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
// m, n, k,
// alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta,
// C, HIPBLAS_R_32I, ldc,
// HIPBLAS_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
// if (status != HIPBLAS_STATUS_SUCCESS)
// {
// std::cout << "CUBLAS ERROR: Status " << status << std::endl;
// }
}
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long long int strideA, long long int strideB, long long int strideC, int batchCount)
{
const int falpha = 1;
const int fbeta = 0;
const void * alpha = &falpha;
const void * beta = &fbeta;
hipblasStatus_t status;
//cout << transposeA << transposeB << endl;
//printf("%i %i %i\n", m,n,k);
//printf("%i %i %i\n", lda,ldb,ldc);
//printf("%i %i %i\n", strideA, strideB, strideC);
//printf("%i\n", batchCount);
status = hipblasGemmStridedBatchedEx(context->m_handle,
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
m, n, k,
alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta,
C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount,
HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
if (status != HIPBLAS_STATUS_SUCCESS)
{
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}
}
int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}
template<int ORDER> int get_leading_dim(int dim1, int dim2)
{
switch(ORDER)
{
case ROW:
return dim2;
break;
case COL:
return dim1;
break;
case COL32:
// 32*row tiles
return dim1*32;
break;
case COL_TURING:
return 32*roundoff(dim1, 8);
break;
case COL_AMPERE:
// 32*32 tiles
return 32*roundoff(dim1, 32);
break;
default:
return 0;
break;
}
}
template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
template int get_leading_dim<COL32>(int dim1, int dim2);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
cout << "" << endl;
cout << "=============================================" << endl;
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
cout << "=============================================" << endl;
cout << "" << endl;
assert(false);
}
template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
cout << "" << endl;
cout << "=============================================" << endl;
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
cout << "=============================================" << endl;
cout << "" << endl;
assert(false);
return 0;
}
int fill_up_to_nearest_multiple(int value, int multiple)
{
return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols)
{
int threads = 512;
int tileCols = fill_up_to_nearest_multiple(numCols, 32);
int n = numRows*tileCols;
int subtile_rows = 128;
int tilesize = 32*subtile_rows;
int num_blocks = numRows/subtile_rows;
num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
num_blocks = num_blocks*(tileCols/32);
assert(threads <= tilesize);
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
#define STATS_THREADS 64
#define STATS_ITEMS 4
#define STATS_ROWS 16
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
{
int tile_cols = STATS_THREADS*STATS_ITEMS;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
int row_tiles = (tiledRows/STATS_ROWS);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
if(nnz_threshold == 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
else if(nnz_threshold != 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
{
int threads = 64;
int items_per_thread = 4;
int tile_cols = threads*items_per_thread;
int tile_rows = 16;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int row_tiles = (tiledRows/tile_rows);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
if(threshold > 0.0f)
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
else
kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
{
int threads = 256;
int items_per_thread = 8;
// we load 128 column values per warp
int tile_cols = 32*items_per_thread;
int tile_rows = 32;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int row_tiles = (tiledRows/tile_rows);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
int outCols = fill_up_to_nearest_multiple(cols, 32);
int outRows = fill_up_to_nearest_multiple(rows, 32);
if(FORMAT == COL_TURING)
{
if(TRANSPOSE)
outRows = fill_up_to_nearest_multiple(cols, 8);
else
outRows = fill_up_to_nearest_multiple(rows, 8);
}
else if(FORMAT == COL_AMPERE)
{
if(TRANSPOSE)
outRows = fill_up_to_nearest_multiple(cols, 32);
else
outRows = fill_up_to_nearest_multiple(rows, 32);
}
else
{
if(TRANSPOSE)
{
outCols = fill_up_to_nearest_multiple(rows, 32);
outRows = cols;
}
}
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
cout << "" << endl;
cout << "=============================================" << endl;
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
cout << "=============================================" << endl;
cout << "" << endl;
assert(false);
return;
}
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{
kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
{
int threads = 256;
// we load 128 column values per warp
int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32);
int tiledRows = 0;
int num_blocks = idx_size;
if(FORMAT == COL_TURING)
{
tiledRows = fill_up_to_nearest_multiple(rows, 8);
}
else if(FORMAT == COL_AMPERE)
{
tiledRows = fill_up_to_nearest_multiple(rows, 32);
}
kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template int igemmlt<COL_TURING, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_TURING, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_TURING, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 1>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows, int cols);
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
#define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
#define MAKE_optimizerStatic8bit(name, gtype) \
template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
float *unorm, float max_unorm, float param_norm, \
float beta1, float beta2, \
float eps, int step, float lr, \
float* quantiles1, float* quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
float weight_decay, \
const float gnorm_scale, int n); \
MAKE_optimizerStatic8bit(ADAM, half)
MAKE_optimizerStatic8bit(ADAM, float)
MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float)
MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
MAKE_optimizerStatic8bitBlockwise(half, ADAM);
MAKE_optimizerStatic8bitBlockwise(float, ADAM);
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#ifndef ops_H
#define ops_H
#include <stdio.h>
#include <iostream>
#include <unistd.h>
#include <assert.h>
#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
#include <hipblas/hipblas.h>
// #include <cublasLt.h>
#include <hipsparse/hipsparse.h>
#include <vector>
#include <functional>
typedef struct cublasLtContext* cublasLtHandle_t;
#define CUDA_CHECK_RETURN(value) { \
hipError_t _m_cudaStat = value; \
if (_m_cudaStat != hipSuccess) { \
fprintf(stderr, "Error %s at line %d in file %s\n", \
hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
exit(1); \
} }
#define THREADS_PER_BLOCKS (512)
#define CHECK_CUSPARSE(value) { \
hipsparseStatus_t _m_cudaStat = value; \
if (_m_cudaStat != HIPSPARSE_STATUS_SUCCESS) { \
fprintf(stderr, "Error <sparse error> at line %d in file %s\n", \
__LINE__, __FILE__); \
exit(1); \
} }
#define THREADS_PER_BLOCKS (512)
inline void checkCudaStatus(hipError_t status) {
if (status != hipSuccess) {
printf("cuda API failed with status %d: %s\n", status, hipGetErrorString(status));
throw std::logic_error("cuda API failed");
}
}
inline int checkCublasStatus(hipblasStatus_t status) {
if (status != HIPBLAS_STATUS_SUCCESS) {
printf("cuBLAS API failed with status %d\n", status);
//throw std::logic_error("cuBLAS API failed");
return 1;
}
return 0;
}
typedef enum Operations_t
{
ksmul = 0,
} Operations_t;
typedef enum Optimizer_t
{
ADAM = 0,
MOMENTUM = 1,
RMSPROP = 2,
LARS = 3,
ADAGRAD = 4,
} Optimizer_t;
typedef enum Transform_t
{
ROW = 0,
COL = 1,
COL32 = 2,
COL_TURING = 3,
COL_AMPERE = 4,
} Transform_t;
class Context
{
public:
hipblasHandle_t m_handle;
Context()
{
hipblasHandle_t handle;
hipblasCreate(&handle);
m_handle = handle;
}
};
// class ContextLt
// {
// public:
// cublasLtHandle_t m_handle;
// ContextLt()
// {
// cublasLtHandle_t handle;
// cublasLtCreate(&handle);
// m_handle = handle;
// }
// };
class ContextCusparse
{
public:
hipsparseHandle_t m_handle;
ContextCusparse()
{
hipsparseHandle_t handle;
hipsparseCreate(&handle);
m_handle = handle;
}
};
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
void quantize(float *code, float *A, unsigned char *out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n);
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
float beta1, float beta2, float eps, float weight_decay,
int step, float lr, const float gnorm_scale, bool skip_zeros, int n);
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
float *unorm, float max_unorm, float param_norm,
float beta1, float beta2,
float eps, int step, float lr,
float* quantiles1, float* quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay,
const float gnorm_scale, int n);
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
bool skip_zeros, int n);
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long long int strideA, long long int strideB, long long int strideC, int batchCount);
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols);
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
#endif
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#if BUILD_CUDA
#include <ops.cuh>
#endif
#include <cpu_ops.h>
// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to
// maintain all that boilerplate
//===================================================================================
// UNMANGLED CALLS
//===================================================================================
#if BUILD_CUDA
void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles<float>(A, code, offset, n); }
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
#define MAKE_FUNC32(fname, oname, gtype, gbits) \
void fname##32bit_g##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
MAKE_FUNC32(momentum, MOMENTUM, float, 32)
MAKE_FUNC32(momentum, MOMENTUM, half, 16)
MAKE_FUNC32(adam, ADAM, float, 32)
MAKE_FUNC32(adam, ADAM, half, 16)
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
#define MAKE_FUNC8(fname, oname, gtype, gbits) \
void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
float *unorm, float max_unorm, float param_norm, \
float beta1, float beta2, \
float eps, int step, float lr, \
float* quantiles1, float* quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
float weight_decay, float gnorm_scale, int n) \
{ \
optimizerStatic8bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
} \
MAKE_FUNC8(adam, ADAM, float, 32)
MAKE_FUNC8(adam, ADAM, half, 16)
MAKE_FUNC8(momentum, MOMENTUM, float, 32)
MAKE_FUNC8(momentum, MOMENTUM, half, 16)
MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
MAKE_BLOCKWISE8(adam, ADAM, half, 16)
MAKE_BLOCKWISE8(adam, ADAM, float, 32)
MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
{ \
transform<dtype, src, target, transpose, bits>(ltHandle, A, out, dim1, dim2); \
} \
MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);
void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 0>(A, out, rows, cols); }
void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 1>(A, out, rows, cols); }
void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 0>(A, out, rows, cols); }
void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 1>(A, out, rows, cols); }
void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
#endif
extern "C"
{
#if BUILD_CUDA
void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
#define MAKE_CFUNC32(name, gtype, gbits) \
void c##name##32bit_g##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
MAKE_CFUNC32(adam, float, 32)
MAKE_CFUNC32(adam, half, 16)
MAKE_CFUNC32(momentum, float, 32)
MAKE_CFUNC32(momentum, half, 16)
MAKE_CFUNC32(rmsprop, float, 32)
MAKE_CFUNC32(rmsprop, half, 16)
MAKE_CFUNC32(adagrad, float, 32)
MAKE_CFUNC32(adagrad, half, 16)
#define MAKE_CFUNC8(name, gtype, gbits) \
void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
float *unorm, float max_unorm, float param_norm, \
float beta1, float beta2, \
float eps, int step, float lr, \
float* quantiles1, float* quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
float weight_decay, float gnorm_scale, int n) \
{ \
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
} \
MAKE_CFUNC8(adam, float, 32)
MAKE_CFUNC8(adam, half, 16)
MAKE_CFUNC8(momentum, float, 32)
MAKE_CFUNC8(momentum, half, 16)
MAKE_CFUNC8(rmsprop, float, 32)
MAKE_CFUNC8(rmsprop, half, 16)
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
MAKE_CBLOCKWISE8(adam, ADAM, float, 32)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
{ gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); }
void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long strideA, long strideB, long strideC, int batchCount)
{ strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); }
Context *get_context(){ return new Context(); }
ContextCusparse *get_cusparse(){ return new ContextCusparse(); }
int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
//{ (cublasLtHandle_t)context->m_handle; return 0; }
//{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
{ \
transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \
} \
MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8)
MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols)
{ dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); }
void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
{ getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }
void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols)
{ doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); }
void ctransform_row2col32(char * A, char *out, int rows, int cols)
{ transform_row2col32(A, out, rows, cols); }
void ctransform_row2col32T(char * A, char *out, int rows, int cols)
{ transform_row2col32T(A, out, rows, cols); }
void ctransform_row2turing(char * A, char *out, int rows, int cols)
{ transform_row2turing(A, out, rows, cols); }
void ctransform_row2turingT(char * A, char *out, int rows, int cols)
{ transform_row2turingT(A, out, rows, cols); }
void ctransform_row2ampere(char * A, char *out, int rows, int cols)
{ transform_row2ampere(A, out, rows, cols); }
void ctransform_row2ampereT(char * A, char *out, int rows, int cols)
{ transform_row2ampereT(A, out, rows, cols); }
void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{ spmm_coo((hipsparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
}
URL92=https://developer.nvidia.com/compute/cuda/9.2/Prod2/local_installers/cuda_9.2.148_396.37_linux
URL100=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux
URL101=https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run
URL102=https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run
URL110=https://developer.download.nvidia.com/compute/cuda/11.0.3/local_installers/cuda_11.0.3_450.51.06_linux.run
URL111=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
URL112=https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run
URL113=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run
URL114=https://developer.download.nvidia.com/compute/cuda/11.4.4/local_installers/cuda_11.4.4_470.82.01_linux.run
URL115=https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run
URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run
URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
URL118=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
CUDA_VERSION=$1
BASE_PATH=$2
if [[ -n "$CUDA_VERSION" ]]; then
if [[ "$CUDA_VERSION" -eq "92" ]]; then
URL=$URL92
FOLDER=cuda-9.2
elif [[ "$CUDA_VERSION" -eq "100" ]]; then
URL=$URL100
FOLDER=cuda-10.0
elif [[ "$CUDA_VERSION" -eq "101" ]]; then
URL=$URL101
FOLDER=cuda-10.1
elif [[ "$CUDA_VERSION" -eq "102" ]]; then
URL=$URL102
FOLDER=cuda-10.2
elif [[ "$CUDA_VERSION" -eq "110" ]]; then
URL=$URL110
FOLDER=cuda-11.0
elif [[ "$CUDA_VERSION" -eq "111" ]]; then
URL=$URL111
FOLDER=cuda-11.1
elif [[ "$CUDA_VERSION" -eq "112" ]]; then
URL=$URL112
FOLDER=cuda-11.2
elif [[ "$CUDA_VERSION" -eq "113" ]]; then
URL=$URL113
FOLDER=cuda-11.3
elif [[ "$CUDA_VERSION" -eq "114" ]]; then
URL=$URL114
FOLDER=cuda-11.4
elif [[ "$CUDA_VERSION" -eq "115" ]]; then
URL=$URL115
FOLDER=cuda-11.5
elif [[ "$CUDA_VERSION" -eq "116" ]]; then
URL=$URL116
FOLDER=cuda-11.6
elif [[ "$CUDA_VERSION" -eq "117" ]]; then
URL=$URL117
FOLDER=cuda-11.7
elif [[ "$CUDA_VERSION" -eq "118" ]]; then
URL=$URL118
FOLDER=cuda-11.8
else
echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
fi
else
echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
fi
FILE=$(basename $URL)
if [[ -n "$CUDA_VERSION" ]]; then
echo $URL
echo $FILE
wget $URL
bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc
echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc
source ~/.bashrc
else
echo ""
fi
#!/bin/bash
BASE_PATH=$1
echo "MAKE SURE LD_LIBRARY_PATH IS EMPTY!"
echo $LD_LIBRARY_PATH
if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
module unload cuda
module unload gcc
rm -rf dist build
make cleaneggs
make cleanlibs
make clean
export CUDA_HOME=
export CUDA_VERSION=
make cpuonly CUDA_VERSION="CPU"
if [ ! -f "./bitsandbytes/libbitsandbytes_cpu.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.0
make cuda110 CUDA_VERSION=110
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.1
make cuda11x CUDA_VERSION=111
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.2
make cuda11x CUDA_VERSION=112
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda112.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.3
make cuda11x CUDA_VERSION=113
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda113.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.4
make cuda11x CUDA_VERSION=114
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.5
make cuda11x CUDA_VERSION=115
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.6
make cuda11x CUDA_VERSION=116
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda116.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.7
make cuda11x CUDA_VERSION=117
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.8
make cuda11x CUDA_VERSION=118
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-10.2
make cuda10x_nomatmul CUDA_VERSION=102
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.0
make cuda110_nomatmul CUDA_VERSION=110
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.1
make cuda11x_nomatmul CUDA_VERSION=111
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.2
make cuda11x_nomatmul CUDA_VERSION=112
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda112_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.3
make cuda11x_nomatmul CUDA_VERSION=113
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda113_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.4
make cuda11x_nomatmul CUDA_VERSION=114
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.5
make cuda11x_nomatmul CUDA_VERSION=115
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.6
make cuda11x_nomatmul CUDA_VERSION=116
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda116_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.7
make cuda11x_nomatmul CUDA_VERSION=117
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-11.8
make cuda11x_nomatmul CUDA_VERSION=118
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
python -m build
python -m twine upload dist/* --verbose
name: 8-bit
channels:
- conda-forge
dependencies:
- python=3.9
- pytest
- pytorch
- torchaudio
- torchvision
- cudatoolkit=11.1
- typer
- ca-certificates
- certifi
- openssl
# No kernel image available
This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. So solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``?
If you are feeling lucky, you can also try to compile the library from source. This can be still problematic if your PATH variables have multiple cuda versions. As such, it is recommended to figure out path conflicts before you proceed with compilation.
__If you encounter any other error not listed here please create an issue. This will help resolve your problem and will help out others in the future.
# fatbinwrap
This error occurs if there is a mismatch between CUDA versions in the C++ library and the CUDA part. Make sure you have right CUDA in your $PATH and $LD_LIBRARY_PATH variable. In the conda base environment you can find the library under:
```bash
ls $CONDA_PREFIX/lib/*cudart*
```
Make sure this path is appended to the `LD_LIBRARY_PATH` so bnb can find the CUDA runtime environment library (cudart).
If this does not fix the issue, please try [compilation from source](compile_from_source.md) next.
If this does not work, please open an issue and paste the printed environment if you call `make` and the associated error when running bnb.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment