"vscode:/vscode.git/clone" did not exist on "da15fdb9caa05904eb27844e69af4b76af2aca46"
Unverified commit 8823cc48, authored by Frank Lee, committed by GitHub

Merge pull request #5310 from hpcaitech/feature/npu

Feature/npu
parents bce9499e 73f4dc57
...@@ -7,8 +7,8 @@ from typing import Dict ...@@ -7,8 +7,8 @@ from typing import Dict
import torch import torch
from torch import Tensor from torch import Tensor
from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils.device import get_current_device
__all__ = ["BaseGradScaler"] __all__ = ["BaseGradScaler"]
...@@ -23,7 +23,7 @@ class BaseGradScaler(ABC): ...@@ -23,7 +23,7 @@ class BaseGradScaler(ABC):
def __init__(self, initial_scale: float, verbose: bool): def __init__(self, initial_scale: float, verbose: bool):
assert initial_scale > 0 assert initial_scale > 0
self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float) self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float)
self._verbose = verbose self._verbose = verbose
if self._verbose: if self._verbose:
......
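Device lookups now go through the accelerator singleton instead of the CUDA-only helper. A minimal sketch of the new call pattern (illustrative only, assuming a ColossalAI build where colossalai.accelerator is available):

import torch
from colossalai.accelerator import get_accelerator

# The accelerator abstracts over CUDA and NPU backends; get_current_device()
# returns a torch.device for whichever backend was detected at import time.
accelerator = get_accelerator()
device = accelerator.get_current_device()
scale = torch.tensor([65536.0], device=device, dtype=torch.float)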
@@ -5,7 +5,7 @@ from typing import Optional
 import torch
-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator
 from .base_grad_scaler import BaseGradScaler
@@ -37,14 +37,20 @@ class DynamicGradScaler(BaseGradScaler):
         hysteresis: int = 2,
         verbose: bool = False,
     ):
+        a = get_accelerator()
+        a.device_count()
         super().__init__(initial_scale, verbose)
         if min_scale:
-            self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
+            self._min_scale = torch.tensor(
+                [min_scale], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
         else:
             self._min_scale = None
         if max_scale:
-            self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
+            self._max_scale = torch.tensor(
+                [max_scale], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
         else:
             self._max_scale = None
@@ -117,7 +123,7 @@ class DynamicGradScaler(BaseGradScaler):
         return state_dict
     def load_state_dict(self, state_dict):
-        self._scale = state_dict["scale"].to(get_current_device())
+        self._scale = state_dict["scale"].to(get_accelerator().get_current_device())
         self._growth_factor = state_dict["growth_factor"]
         self._backoff_factor = state_dict["backoff_factor"]
         self._hysteresis = state_dict["hysteresis"]
@@ -5,8 +5,8 @@ import torch
 import torch.distributed as dist
 from torch import Tensor
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.utils import get_current_device
 from .base import MixedPrecisionMixin
@@ -40,7 +40,7 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
             max_scale=max_scale,
         )
         self.optim_state = OptimState.UNSCALED
-        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
+        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
     @property
     def loss_scale(self) -> float:
...
@@ -4,10 +4,10 @@ from typing import Dict, Tuple
 import torch
 from torch.optim import Optimizer
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 from .base_offload_module import BaseOffloadModule
 from .region import Region
@@ -79,7 +79,9 @@ class AMPOptimizer(OptimizerWrapper):
             hysteresis=hysteresis,
             max_scale=max_scale,
         )
-        self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+        self._found_overflow: torch.Tensor = torch.zeros(
+            1, dtype=torch.int64, device=get_accelerator().get_current_device()
+        )
         self._logger = get_dist_logger()
     def _set_grad_ptr(self):
...
@@ -11,7 +11,7 @@ except:
 import torch
 from torch.fx.node import Node
-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator
 from .region import Region
 from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
@@ -57,7 +57,10 @@ class Solver(ABC):
         if memory_budget > 0:
             self.memory_budget = memory_budget * self.error_factor
         else:
-            self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
+            self.memory_budget = (
+                torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
+                * self.error_factor
+            )
         self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
         self.comp_power: float = self._extract_computing_power()
...
@@ -5,8 +5,8 @@ import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
+from colossalai.accelerator import get_accelerator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.utils.device import autocast
 from .mixed_precision_base import MixedPrecision
@@ -89,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
         super().__init__(module)
     def forward(self, *args, **kwargs):
-        with autocast():
+        with get_accelerator().autocast():
             return self.module(*args, **kwargs)
...
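get_accelerator().autocast() replaces the CUDA-specific autocast helper, so the same wrapper works on NPU. A short usage sketch (hypothetical model and input, assuming the accelerator exposes autocast() as a context manager, as the diff above shows):

import torch
import torch.nn as nn
from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()
model = nn.Linear(16, 16).to(device)
x = torch.randn(4, 16, device=device)

# Mixed-precision forward pass under the backend-appropriate autocast context.
with get_accelerator().autocast():
    y = model(x)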
@@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
@@ -27,8 +28,6 @@ from colossalai.checkpoint_io.utils import (
 from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.utils import get_current_device
-from colossalai.utils.device import IS_NPU_AVAILABLE
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -366,11 +365,11 @@ class GeminiPlugin(DPPluginBase):
     ) -> None:
         super().__init__()
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
-        if IS_NPU_AVAILABLE:
+        if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
-            chunk_init_device=(chunk_init_device or get_current_device()),
+            chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
             placement_policy=placement_policy,
             enable_gradient_accumulation=enable_gradient_accumulation,
             shard_param_frac=shard_param_frac,
@@ -455,7 +454,7 @@ class GeminiPlugin(DPPluginBase):
     def supported_devices(self) -> List[str]:
         return ["cuda", "npu"]
     def prepare_dataloader(
         self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
     ):
@@ -486,7 +485,10 @@ class GeminiPlugin(DPPluginBase):
         zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
         extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
         sampler = DistributedSampler(
-            dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle
+            dataset,
+            num_replicas=zero_world_size * extra_dp_world_size,
+            rank=zero_rank * extra_dp_world_size + extra_dp_rank,
+            shuffle=shuffle,
         )
         # Deterministic dataloader
...
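The reformatted DistributedSampler call keeps the same rank arithmetic: the ZeRO axis and the extra data-parallel axis are flattened into a single sampler rank. A small worked example of that flattening (hypothetical sizes, not taken from the diff):

zero_world_size, extra_dp_world_size = 4, 2      # 8 replicas in total
zero_rank, extra_dp_rank = 2, 1                  # this process's mesh coordinates

num_replicas = zero_world_size * extra_dp_world_size      # 8
rank = zero_rank * extra_dp_world_size + extra_dp_rank    # 2 * 2 + 1 = 5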
@@ -18,6 +18,7 @@ from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
@@ -28,7 +29,6 @@ from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
-from colossalai.utils.device import get_current_device
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from .pp_plugin_base import PipelinePluginBase
@@ -82,7 +82,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             self.mixed_precision = torch.bfloat16
         if self.mixed_precision is not None:
             module = module.to(self.mixed_precision)
-        module = module.to(get_current_device())
+        module = module.to(get_accelerator().get_current_device())
         # setting input type cast when using mixed precision
         self.convert_fn = None
@@ -345,7 +345,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
             if self.pp_size > 1:
@@ -385,7 +387,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 total_norm_exponentiated += grad_norm_exponentiated
             total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
             )
             if self.tp_size > 1:
                 # compute norm in tp process group
@@ -542,7 +544,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             # so we need to calculate the norm of 'tp' and 'pp' gradients.
             total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -586,7 +590,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 total_norm_exponentiated += grad_norm_exponentiated
             total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
             )
             if self.tp_size > 1:
                 # compute norm in tp process group
@@ -798,7 +802,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             # so we only need to calculate the norm 'tp' of 'pp' gradients.
             total_norm = super()._compute_grad_norm(gradients, norm_type)
-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -838,7 +844,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                 total_norm_exponentiated += grad_norm_exponentiated
             total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
             )
             if dp_size > 1:
                 # compute norm in dp process group
...
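Only the device used for total_norm_cuda changes in the hunks above; the reduction pattern itself is unchanged. A minimal sketch of that pattern for the infinity norm (assumes torch.distributed is already initialized; tp_pg and pp_pg are illustrative names for existing process groups):

import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator

def global_inf_norm(gradients, tp_pg=None, pp_pg=None):
    # Local max over this rank's gradients ...
    total_norm = max(grad.data.abs().max() for grad in gradients)
    total_norm_cuda = torch.tensor(
        [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
    )
    # ... then a MAX all-reduce over each parallel group that shards the gradients.
    if tp_pg is not None:
        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=tp_pg)
    if pp_pg is not None:
        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=pp_pg)
    return total_norm_cuda.item()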
@@ -12,6 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
+from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
@@ -24,7 +25,6 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
-from colossalai.utils import get_current_device
 from colossalai.zero import LowLevelZeroOptimizer
 from .dp_plugin_base import DPPluginBase
@@ -52,7 +52,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
             self.dtype = torch.bfloat16
         if self.dtype is not None:
             module = module.to(self.dtype)
-        module = module.to(get_current_device())
+        module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
         if self.dtype is not None:
...
@@ -6,12 +6,12 @@ import warnings
 from pathlib import Path
 from typing import Dict, Union
-import torch
 import torch.distributed as dist
+from colossalai.accelerator import get_accelerator
 from colossalai.context import Config
 from colossalai.logging import get_dist_logger
-from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
+from colossalai.utils import set_seed
 def launch(
@@ -47,17 +47,18 @@ def launch(
     if rank == 0:
         warnings.warn("`config` is deprecated and will be removed soon.")
-    if IS_NPU_AVAILABLE and backend == "nccl":
-        backend = "hccl"
+    cur_accelerator = get_accelerator()
+    backend = cur_accelerator.communication_backend
     # init default process group
     init_method = f"tcp://[{host}]:{port}"
     dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
     # set cuda device
-    if torch.cuda.is_available() or IS_NPU_AVAILABLE:
-        # if local rank is not given, calculate automatically
-        set_device(local_rank)
+    # if local rank is not given, calculate automatically
+    if cur_accelerator.support_set_device:
+        cur_accelerator.set_device(local_rank)
     set_seed(seed)
...
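launch() now asks the accelerator for its communication backend ("nccl" on CUDA, "hccl" on NPU) instead of special-casing NPU. A condensed sketch of the resulting flow (argument handling omitted; it only uses the accelerator attributes shown in the diff):

import torch.distributed as dist
from colossalai.accelerator import get_accelerator

def launch_sketch(rank, world_size, host, port, local_rank):
    cur_accelerator = get_accelerator()
    backend = cur_accelerator.communication_backend      # e.g. "nccl" or "hccl"
    dist.init_process_group(rank=rank, world_size=world_size, backend=backend,
                            init_method=f"tcp://[{host}]:{port}")
    # Bind this process to its local device when the backend supports it.
    if cur_accelerator.support_set_device:
        cur_accelerator.set_device(local_rank)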
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
__all__ = [
"LayerNorm",
"FusedScaleMaskSoftmax",
"MultiHeadAttention",
]
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "column_remap.cuh"
#include "util.cuh"
const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;
__global__ void column_remap_kernel
(
const half* __restrict__ x,
half* __restrict__ x_new,
const int x_width,
const int x_height,
const uint32_t* x_map
)
{
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
if (x_column >= x_width) return;
//if (x_row >= x_height) return;
int x_stride = x_width;
int x_idx = x_row * x_stride + x_column;
int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
int x_idx_end = x_row_end * x_stride + x_column;
int s_column = x_map[x_column];
int s_idx = x_row * x_stride + s_column;
while (x_idx < x_idx_end)
{
x_new[x_idx] = x[s_idx];
x_idx += x_stride;
s_idx += x_stride;
}
}
// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
)
{
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
1
);
column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}
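The kernel gathers columns of x according to x_map so the remapped activations line up with a weight matrix whose rows were reordered by group index. The effect, expressed as a plain tensor-indexing reference (a sketch of the semantics, not the CUDA implementation):

import torch

def column_remap_reference(x: torch.Tensor, x_map: torch.Tensor) -> torch.Tensor:
    # x: (height, width) half tensor; x_map: (width,) int tensor of source columns.
    # Row r, column c of the output reads from column x_map[c] of the same row,
    # so that seq_x @ seq_w == x @ w when w's rows were permuted the same way.
    return x[:, x_map.long()]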
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _column_remap_cuh
#define _column_remap_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
);
#endif
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
#endif
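atomicAdd_half emulates a half-precision atomic add on devices without native support by doing a compare-and-swap on the aligned 32-bit word that contains the half. The word masking it relies on can be checked with a small host-side sketch (pure illustration of the bit manipulation, not the CUDA code):

def replace_half_in_word(word: int, new_bits: int, upper: bool) -> int:
    # Return 'word' with either its upper or lower 16 bits replaced by new_bits,
    # mirroring the (address & 2) branch in atomicAdd_half.
    if upper:
        return (word & 0x0000FFFF) | ((new_bits & 0xFFFF) << 16)
    return (word & 0xFFFF0000) | (new_bits & 0xFFFF)

# The CAS loop in the kernel re-reads the word and retries whenever another thread wrote it first.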
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#define _cuda_buffers_cu
#include "cuda_buffers.cuh"
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;
CudaBuffers::CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
) :
device(_device),
temp_state_size(_temp_state_size),
temp_state(_temp_state),
temp_dq(_temp_dq)
{
cudaSetDevice(_device);
cudaStreamCreate(&alt_stream_1);
cudaStreamCreate(&alt_stream_2);
cudaStreamCreate(&alt_stream_3);
cudaEventCreate(&alt_stream_1_done);
cudaEventCreate(&alt_stream_2_done);
cudaEventCreate(&alt_stream_3_done);
}
CudaBuffers::~CudaBuffers()
{
cudaStreamDestroy(alt_stream_1);
cudaStreamDestroy(alt_stream_2);
cudaStreamDestroy(alt_stream_3);
cudaEventDestroy(alt_stream_1_done);
cudaEventDestroy(alt_stream_2_done);
cudaEventDestroy(alt_stream_3_done);
}
CudaBuffers* get_buffers(const int device_index)
{
return g_buffers[device_index];
}
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
)
{
CudaBuffers* buffers = new CudaBuffers
(
_device,
_temp_state_size,
_temp_state,
_temp_dq
);
g_buffers[_device] = buffers;
}
void cleanup_buffers_cuda()
{
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
{
if (!g_buffers[i]) continue;
delete g_buffers[i];
g_buffers[i] = NULL;
}
}
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
const int CUDA_MAX_DEVICES = 16;
// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif
class CudaBuffers
{
public:
int device;
half* temp_state; // [max_hidden_rows * intermediate_size]
int temp_state_size;
half* temp_dq; // size of largest quant tensor * 8
cudaStream_t alt_stream_1;
cudaStream_t alt_stream_2;
cudaStream_t alt_stream_3;
cudaEvent_t alt_stream_1_done;
cudaEvent_t alt_stream_2_done;
cudaEvent_t alt_stream_3_done;
CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
~CudaBuffers();
};
CudaBuffers* get_buffers(const int device_index);
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
void cleanup_buffers_cuda();
#endif
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _hip_compat_cuh
#define _hip_compat_cuh
// Workaround for a bug in hipamd, backported from upstream.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
return __half_raw{
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp
// Workaround for hipify_python using rocblas instead of hipblas.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm
#endif
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "q4_matrix.cuh"
#include "q4_matmul.cuh"
#include "column_remap.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
void check_cuda(cudaError_t ret)
{
switch (ret)
{
case cudaSuccess:
break;
case cudaUnspecified:
printf(" **** Unspecified error\n");
TORCH_CHECK(false, "CUDA error");
break;
default:
printf(" **** CUDA error\n"); \
printf(" **** %s\n", cudaGetErrorString(ret)); \
TORCH_CHECK(false, "CUDA error"); \
break;
}
}
// Some decluttering macros
#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
TORCH_CHECK(__index >= 0, "no device index"); \
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
TORCH_CHECK_DTYPE(__w, kInt); \
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize;
}
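get_groupsize recovers the quantization group size from the packed shapes: qweight stores eight 4-bit rows per int32 row, so the dequantized height is w.size(0) * 8, and dividing by the number of zero-point rows gives the group size. A small reference with hypothetical shapes:

def groupsize_from_shapes(qweight_rows: int, qzeros_rows: int) -> int:
    # Eight 4-bit rows are packed into each int32 row of qweight.
    height = qweight_rows * 8
    assert height % qzeros_rows == 0, "w.shape[-2] must be a multiple of zeros.shape[-2]"
    return height // qzeros_rows

# e.g. qweight of shape (512, 4096) and qzeros of shape (32, 512): 512 * 8 // 32 == 128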
// Tuning parameters
ExLlamaTuning tuningParams;
void set_tuning_params
(
int matmul_recons_thd,
bool matmul_fused_remap,
bool matmul_no_half2
)
{
tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.matmul_no_half2 = matmul_no_half2;
}
// Release all unmanaged objects allocated by the extension
void cleanup()
{
cleanup_buffers_cuda();
g_q4_free_matrices();
}
// Prepare buffers for forward pass
void prepare_buffers
(
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_dq
)
{
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
prepare_buffers_cuda
(
device_index,
// buffer size used for sanity checks
temp_state.numel(),
(half*) temp_state.data_ptr(),
(half*) temp_dq.data_ptr()
);
}
// Create Q4Matrix, return handle
uintptr_t make_q4
(
torch::Tensor qweight,
torch::Tensor qzeros,
torch::Tensor scales,
torch::Tensor g_idx,
int device
)
{
TORCH_CHECK_DTYPE(qweight, kInt);
TORCH_CHECK_DTYPE(qzeros, kInt);
TORCH_CHECK_DTYPE(scales, kHalf);
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
int width = qweight.size(1);
int height = qweight.size(0) * 8;
int groups = qzeros.size(0);
Q4Matrix* m = new Q4Matrix
(
height,
width,
groups,
(uint32_t*) qweight.data_ptr(),
(uint32_t*) qzeros.data_ptr(),
(half*) scales.data_ptr(),
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
device
);
g_q4_keep_matrix(m);
return reinterpret_cast<uintptr_t> (m);
}
// Matmul half @ quant -> half
void q4_matmul
(
torch::Tensor x,
uintptr_t w,
torch::Tensor out
)
{
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int x_height = x.size(0);
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr()
);
}
else
{
q4_matmul_recons_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
at::cuda::getCurrentCUDABlasHandle()
);
}
}
// Remap columns in half tensor
void column_remap
(
torch::Tensor x,
torch::Tensor x_new,
torch::Tensor x_map
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(x_new, kHalf);
TORCH_CHECK_DTYPE(x_map, kInt);
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
int height = x.size(0);
int width = x.size(1);
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
column_remap_cuda
(
(half*) x.data_ptr(),
(half*) x_new.data_ptr(),
height,
width,
(uint32_t*) x_map.data_ptr()
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
}
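Once built, the extension is driven from Python through these five entry points. A hedged usage sketch follows: the module name exllama_ext, the tensor shapes, and the zero-filled dummy weights are illustrative assumptions; the ColossalAI wrapper that actually calls this extension is not part of this diff.

import torch
import exllama_ext  # hypothetical import name for the compiled extension

device = torch.device("cuda:0")

# Dummy GPTQ-style tensors for a 4096 x 4096 layer with group size 128:
# qweight packs 8 four-bit rows per int32 row, qzeros packs 8 four-bit zero points per int32 column.
qweight = torch.zeros(512, 4096, dtype=torch.int32, device=device)
qzeros = torch.zeros(32, 512, dtype=torch.int32, device=device)
scales = torch.ones(32, 4096, dtype=torch.half, device=device)
g_idx = torch.empty(0, dtype=torch.int32, device="meta")  # meta tensor signals "no act-order index"

temp_state = torch.empty(2048 * 4096, dtype=torch.half, device=device)
temp_dq = torch.empty(512 * 4096 * 8, dtype=torch.half, device=device)

exllama_ext.set_tuning_params(8, False, False)  # matmul_recons_thd, matmul_fused_remap, matmul_no_half2
exllama_ext.prepare_buffers(device, temp_state, temp_dq)

# make_q4 returns an opaque handle that q4_matmul consumes.
q4_handle = exllama_ext.make_q4(qweight, qzeros, scales, g_idx, device.index)

x = torch.randn(4, 4096, dtype=torch.half, device=device)
out = torch.empty(4, 4096, dtype=torch.half, device=device)
exllama_ext.q4_matmul(x, q4_handle, out)  # out = x @ dequant(qweight)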
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _matrix_cuh
#define _matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
};
class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
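MatrixView_q4_row and MatrixView_q4_column differ only in which axis is packed eight-to-a-word. The row-major item() can be mirrored in Python to make the addressing explicit (reference sketch over a flat list of uint32 words):

def q4_row_item(data, row, column, width):
    # Eight 4-bit values are packed per 32-bit word along a row; column & 7 selects the nibble.
    word = data[row * (width // 8) + column // 8]
    shift = (column & 0x07) * 4
    return (word >> shift) & 0x0F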
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
__device__ __forceinline__ half2 dot_product_8
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
half2 tmp = __hmul2(*h_ptr++, v_01);
tmp = __hfma2(*h_ptr++, v_23, tmp);
tmp = __hfma2(*h_ptr++, v_45, tmp);
tmp = __hfma2(*h_ptr++, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half* h_ptr = h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(*h_ptr++, v_0);
tmp = __hfma(*h_ptr++, v_1, tmp);
tmp = __hfma(*h_ptr++, v_2, tmp);
tmp = __hfma(*h_ptr++, v_3, tmp);
tmp = __hfma(*h_ptr++, v_4, tmp);
tmp = __hfma(*h_ptr++, v_5, tmp);
tmp = __hfma(*h_ptr++, v_6, tmp);
tmp = __hfma(*h_ptr++, v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
__device__ __forceinline__ half2 dot_product_8_x_map
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
half h_0 = h_ptr[*x_map_ptr++];
half h_1 = h_ptr[*x_map_ptr++];
half h_2 = h_ptr[*x_map_ptr++];
half h_3 = h_ptr[*x_map_ptr++];
half h_4 = h_ptr[*x_map_ptr++];
half h_5 = h_ptr[*x_map_ptr++];
half h_6 = h_ptr[*x_map_ptr++];
half h_7 = h_ptr[*x_map_ptr++];
half2 h_01 = __halves2half2(h_0, h_1);
half2 h_23 = __halves2half2(h_2, h_3);
half2 h_45 = __halves2half2(h_4, h_5);
half2 h_67 = __halves2half2(h_6, h_7);
half2 tmp = __hmul2(h_01, v_01);
tmp = __hfma2(h_23, v_23, tmp);
tmp = __hfma2(h_45, v_45, tmp);
tmp = __hfma2(h_67, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_x_map_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
#endif
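Each dot_product_8* variant consumes one packed 32-bit word of the quantized matrix per step: it splits the word into eight 4-bit values, subtracts the (pre-incremented) zero point, and accumulates scale * q * x over eight activations. A scalar Python reference of one such step (illustrative; the CUDA code works on half/half2 values and walks raw pointers instead):

def dot_product_8_reference(acc, x8, packed_word, scale, zero):
    # x8: eight activation values; packed_word: uint32 holding eight 4-bit weights.
    for i in range(8):
        q = (packed_word >> (4 * i)) & 0x0F
        acc += scale * (q - zero) * x8[i]
    return acc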
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include "util.cuh"
#include "matrix.cuh"
#include "cu_compat.cuh"
#include "cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "hip_compat.cuh"
#endif
const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
typedef void (*fp_q4_matmul_kernel)
(
const half*,
const uint32_t*,
half*,
const half*,
const uint32_t*,
const int,
const int,
const int,
const int,
const int,
const uint32_t*,
bool
);
template<bool use_half2, bool use_groupsize, bool use_x_map>
__global__ void q4_matmul_kernel
(
const half* __restrict__ x,
const uint32_t* __restrict__ w,
half* __restrict__ out,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int dim,
const int width,
const int groupsize,
const int block_size_z,
const uint32_t* __restrict__ x_map,
bool no_zero
)
{
// Start of block
int x_column = block_size_z * blockIdx.z;
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
int iterations = (x_column_end - x_column) / 8;
// Views
MatrixView_half x_(x, height, dim);
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
MatrixView_q4_column w_(w, dim, width);
MatrixView_half_rw out_(out, height, width);
// Zero output
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
{
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
__syncthreads();
}
// Loop over part of x row (and w column)
half2 acc = {};
half acc_h = {};
if constexpr (use_groupsize)
{
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
// could be slightly faster
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
{
if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
}
}
else
{
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
for (int k = x_column; k < x_column + iterations * 8; k += 8)
{
if constexpr (use_half2)
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
else
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
}
}
// Add to block result
if constexpr (use_half2)
{
half result = __hadd(__low2half(acc), __high2half(acc));
atomicAdd(out_.item_ptr(x_row, w_column), result);
}
else
{
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
}
}
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
{
// <bool use_half2, bool use_groupsize, bool use_x_map>
if (tuningParams->matmul_no_half2) {
if (block_size_z % groupsize == 0) {
if (x_map) return q4_matmul_kernel<false, true, true >;
else return q4_matmul_kernel<false, true, false>;
} else {
if (x_map) return q4_matmul_kernel<false, false, true >;
else return q4_matmul_kernel<false, false, false>;
}
} else {
if (block_size_z % groupsize == 0)
{
if (x_map) return q4_matmul_kernel<true, true, true >;
else return q4_matmul_kernel<true, true, false>;
} else {
if (x_map) return q4_matmul_kernel<true, false, true >;
else return q4_matmul_kernel<true, false, false>;
}
}
};
// Compute y = x @ w
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero,
cudaStream_t alt_stream
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
uint32_t* x_map = w->cuda_x_map;
const half* x_mapped = x;
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
{
CudaBuffers* buffers = get_buffers(w->device);
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
x_map = NULL;
}
int block_size_z;
if (w->width == 4096) block_size_z = 384; // 7B
else if (w->width == 11008) block_size_z = 256;
else if (w->width == 5120) block_size_z = 384; // 13B
else if (w->width == 13824) block_size_z = 256;
else if (w->width == 6656) block_size_z = 256; // 33B
else if (w->width == 17920) block_size_z = 128;
else block_size_z = 256;
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
dim3 threads(THREADS_X, THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height + threads.y - 1) / threads.y,
(dim + block_size_z - 1) / block_size_z
);
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}
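The launch grid tiles the output in THREADS_X-wide column blocks and splits the reduction dimension into block_size_z slices, with each z-slice atomically adding its partial sum into out. A quick grid-size calculation for illustrative shapes (x with 16 rows against a 4096 x 11008 weight; these sizes are assumptions, not values from the source):

height, dim, width = 16, 4096, 11008
THREADS_X, THREADS_Y, block_size_z = 32, 1, 256   # block_size_z chosen by w->width as in the code above

blocks = (
    (width + THREADS_X - 1) // THREADS_X,        # 344 column blocks
    (height + THREADS_Y - 1) // THREADS_Y,       # 16 row blocks
    (dim + block_size_z - 1) // block_size_z,    # 16 slices of the reduction dimension
)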
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
CudaBuffers* buffers = get_buffers(w->device);
const half* x_mapped = x;
if (w->cuda_x_map)
{
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
}
w->reconstruct(buffers->temp_dq);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
const float alpha = 1.0f;
const float beta = no_zero ? 1.0f : 0.0f;
cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
#else
const half alpha = __float2half(1.0f);
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
#endif
}