"llama/vscode:/vscode.git/clone" did not exist on "078f666f73422edc1a3819332c03b6f467d064f4"
Commit 37c494a7 authored by Zhekai Zhang

Initial release

import os

import torch
from diffusers import FluxPipeline
from huggingface_hub import hf_hub_download

from ..models.flux import inject_pipeline, load_quantized_model


def quantize_t5(pipe: FluxPipeline, qencoder_path: str):
    assert os.path.exists(qencoder_path), f"qencoder_path {qencoder_path} does not exist"
    from deepcompressor.backend.tinychat.linear import W4Linear

    named_modules = {}
    qencoder_state_dict = torch.load(qencoder_path, map_location="cpu")
    for name, module in pipe.text_encoder_2.named_modules():
        assert isinstance(name, str)
        if isinstance(module, torch.nn.Linear):
            suffix = [".q", ".k", ".v", ".o", ".wi_0"]
            if f"{name}.qweight" in qencoder_state_dict and name.endswith(tuple(suffix)):
                print(f"Switching {name} to W4Linear")
                qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
                qmodule.qweight.data.copy_(qencoder_state_dict[f"{name}.qweight"])
                if qmodule.bias is not None:
                    qmodule.bias.data.copy_(qencoder_state_dict[f"{name}.bias"])
                qmodule.scales.data.copy_(qencoder_state_dict[f"{name}.scales"])
                qmodule.scaled_zeros.data.copy_(qencoder_state_dict[f"{name}.scaled_zeros"])
                # modeling_t5.py: T5DenseGatedActDense needs dtype of weight
                qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
                parent_name, child_name = name.rsplit(".", 1)
                setattr(named_modules[parent_name], child_name, qmodule)
        else:
            named_modules[name] = module


def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> FluxPipeline:
    qmodel_device = kwargs.pop("qmodel_device", "cuda:0")
    qmodel_device = torch.device(qmodel_device)
    if qmodel_device.type != "cuda":
        raise ValueError(f"qmodel_device = {qmodel_device} is not a CUDA device")

    qmodel_path = kwargs.pop("qmodel_path")
    qencoder_path = kwargs.pop("qencoder_path", None)

    if not os.path.exists(qmodel_path):
        hf_repo_id = os.path.dirname(qmodel_path)
        filename = os.path.basename(qmodel_path)
        qmodel_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)

    pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
    m = load_quantized_model(qmodel_path, 0 if qmodel_device.index is None else qmodel_device.index)
    inject_pipeline(pipeline, m)

    if qencoder_path is not None:
        assert isinstance(qencoder_path, str)
        if not os.path.exists(qencoder_path):
            hf_repo_id = os.path.dirname(qencoder_path)
            filename = os.path.basename(qencoder_path)
            qencoder_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
        quantize_t5(pipeline, qencoder_path)
    return pipeline
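

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). This is how `from_pretrained`
# above might be called; the quantized-model path below is a hypothetical
# "repo_id/filename" placeholder, not a file shipped with this commit.
#
#   import torch
#   from nunchaku.pipelines import flux as nunchaku_flux  # assumed module path
#
#   pipeline = nunchaku_flux.from_pretrained(
#       "black-forest-labs/FLUX.1-schnell",      # any diffusers FluxPipeline checkpoint
#       torch_dtype=torch.bfloat16,
#       qmodel_path="some-org/some-repo/transformer.safetensors",  # local path or HF repo/filename
#       qencoder_path=None,                      # optional W4 T5 encoder state dict
#       qmodel_device="cuda:0",
#   )
#   image = pipeline("a photo of a cat", num_inference_steps=4).images[0]
# ---------------------------------------------------------------------------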
import os

import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

if __name__ == "__main__":
    fp = open("nunchaku/__version__.py", "r").read()
    version = eval(fp.strip().split()[-1])

    ROOT_DIR = os.path.dirname(__file__)

    INCLUDE_DIRS = [
        "src",
        "third_party/cutlass/include",
        "third_party/json/include",
        "third_party/mio/include",
        "third_party/spdlog/include",
    ]
    INCLUDE_DIRS = ["-I" + ROOT_DIR + "/" + dir for dir in INCLUDE_DIRS]

    DEBUG = False

    def ncond(s) -> list:
        if DEBUG:
            return []
        else:
            return [s]

    def cond(s) -> list:
        if DEBUG:
            return [s]
        else:
            return []

    CXX_FLAGS = ["-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++20", "-UNDEBUG", "-Og", *INCLUDE_DIRS]
    NVCC_FLAGS = [
        "-DBUILD_NUNCHAKU=1",
        "-gencode", "arch=compute_86,code=sm_86",
        "-gencode", "arch=compute_89,code=sm_89",
        "-g",
        "-std=c++20",
        "-UNDEBUG",
        "-Xcudafe",
        "--diag_suppress=20208",  # spdlog: 'long double' is treated as 'double' in device code
        *cond("-G"),
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_HALF2_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--threads=2",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--generate-line-info",
        "--ptxas-options=--allow-expensive-optimizations=true",
        *INCLUDE_DIRS,
    ]

    nunchaku_extension = CUDAExtension(
        name="nunchaku._C",
        sources=[
            "nunchaku/csrc/pybind.cpp",
            "src/interop/torch.cpp",
            "src/activation.cpp",
            "src/layernorm.cpp",
            "src/Linear.cpp",
            *ncond("src/FluxModel.cpp"),
            "src/Serialization.cpp",
            *ncond("src/kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
            *ncond("src/kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
            *ncond("src/kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
            *ncond("src/kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu"),
            *ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"),
            *ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"),
            *ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"),
            *ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"),
            "src/kernels/activation_kernels.cu",
            "src/kernels/layernorm_kernels.cu",
            "src/kernels/misc_kernels.cu",
            "src/kernels/gemm_w4a4.cu",
            "src/kernels/gemm_batched.cu",
            "src/kernels/gemm_f16.cu",
            "src/kernels/awq/gemv_awq.cu",
            *ncond("src/kernels/flash_attn/flash_api.cpp"),
            *ncond("src/kernels/flash_attn/flash_api_adapter.cpp"),
        ],
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
    )

    setuptools.setup(
        name="nunchaku",
        version=version,
        packages=setuptools.find_packages(),
        ext_modules=[nunchaku_extension],
        cmdclass={"build_ext": BuildExtension},
    )
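
# Build note (not part of the original setup.py): a sketch of how this extension
# would typically be compiled, assuming a CUDA toolkit and a PyTorch build that
# match the sm_86/sm_89 gencode flags above:
#
#   pip install -e .                    # builds nunchaku._C via BuildExtension
#   # or: python setup.py build_ext --inplace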
#include "FluxModel.h"
#include "kernels/misc_kernels.h"
#include "kernels/flash_attn/flash_api.h"
#include "kernels/gemm_batched.h"
#include "activation.h"
#include <nvtx3/nvToolsExt.h>
#include <iostream>
using spdlog::fmt_lib::format;
Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
Tensor ff_output = std::get<Tensor>(fc2.forward_quant(
std::get<GEMM_W4A4::QuantizedActivation>(fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2)))
);
return ff_output;
}
// Tensor forward_mlp(GEMM_W8A8 &fc1, GEMM_W8A8 &fc2, Tensor norm_hidden_states) {
// Tensor ff_output = fc2.forward(fc1.forward(norm_hidden_states), GEMM_W8A8::FuseOptions::GELU);
// return ff_output;
// }
Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
return std::get<Tensor>(fc.forward(x));
}
// Tensor forward_fc(GEMM_W8A8 &fc, Tensor x) {
// return fc.forward(x);
// }
AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device) :
dim(dim),
linear(dim, 3 * dim, true, dtype, device),
norm(dim, 1e-6, false, dtype, device)
{
registerChildren
(linear, "linear")
(norm, "norm")
;
}
AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor emb) {
debug("emb_input", emb);
emb = linear.forward(Silu::forward(emb));
debug("emb_linear", emb);
auto &&[shift_msa, scale_msa, gate_msa] = split_mod<3>(emb);
debug("scale_msa", scale_msa);
debug("shift_msa", shift_msa);
debug("x", x);
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
mul_add(norm_x, scale_msa, shift_msa);
return Output{norm_x, gate_msa};
}
AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device) :
dim(dim), pre_only(pre_only),
linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
norm(dim, 1e-6, false, dtype, device)
{
registerChildren
(linear, "linear")
(norm, "norm")
;
}
AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
debug("x", x);
debug("emb_input", emb);
emb = linear.forward(Silu::forward(emb));
debug("emb_linear", emb);
if (pre_only) {
auto &&[shift_msa, scale_msa] = split_mod<2>(emb);
debug("shift_msa", shift_msa);
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
mul_add(norm_x, scale_msa, shift_msa);
debug("norm_x_scaled", norm_x);
return Output{norm_x};
} else {
auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = split_mod<6>(emb);
debug("shift_msa", shift_msa);
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
mul_add(norm_x, scale_msa, shift_msa);
debug("norm_x_scaled", norm_x);
return Output{norm_x, gate_msa, shift_mlp, scale_mlp, gate_mlp};
}
}
Attention::Attention(int num_heads, int dim_head, Device device) :
num_heads(num_heads), dim_head(dim_head)
{
headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
for (int i = 0; i < num_heads; i++) {
headmask_type.data_ptr<int32_t>()[i] = i + 1;
}
headmask_type = headmask_type.copy(device);
}
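// Layout sketch (inferred from the asserts in forward() below, not a spec from the
// original commit): qkv is packed as [batch_size, num_tokens, num_heads * dim_head * 3],
// i.e. Q, K and V for all heads are stored contiguously per token. When pool_qkv is valid
// and sparsityRatio > 0, a POOL_SIZE-downsampled copy of qkv is used to score 128x128
// token blocks (gemm_batched_fp16 of pooled Q against pooled K) and topk() keeps roughly
// a (1 - sparsityRatio) fraction of them as the blockmask handed to mha_fwd_block().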
Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
assert(qkv.ndims() == 3);
const Device device = qkv.device();
const int batch_size = qkv.shape[0];
const int num_tokens = qkv.shape[1];
assert(qkv.shape[2] == num_heads * dim_head * 3);
constexpr int POOL_SIZE = 128;
const int pool_tokens = num_tokens / POOL_SIZE;
Tensor blockmask;
if (pool_qkv.valid()) {
assert(pool_qkv.shape[0] == batch_size);
assert(pool_qkv.shape[1] == pool_tokens);
assert(pool_qkv.shape[2] == num_heads * dim_head * 3);
}
Tensor pool_score = Tensor::allocate({batch_size, num_heads, pool_tokens, pool_tokens}, Tensor::FP32, device);
if (pool_qkv.valid() && sparsityRatio > 0) {
pool_qkv = pool_qkv.view({batch_size, pool_tokens, 3, num_heads, dim_head});
pool_qkv = pool_qkv.transpose(1, 2).transpose(2, 3); // [batch_size, 3, num_heads, poolTokens, dim_head]
for (int i = 0; i < batch_size; i++) {
Tensor pool_q = pool_qkv.slice(0, i, i+1).slice(1, 0, 1);
Tensor pool_k = pool_qkv.slice(0, i, i+1).slice(1, 1, 2);
Tensor pool_s = pool_score.slice(0, i, i+1);
gemm_batched_fp16(pool_q, pool_k, pool_s);
}
}
blockmask = topk(pool_score, pool_tokens * (1 - sparsityRatio));
if (cu_seqlens_cpu.valid()) {
if (cu_seqlens_cpu.shape[0] != batch_size + 1) {
cu_seqlens_cpu = Tensor{};
} else {
for (int i = 0; i <= batch_size; i++) {
if (cu_seqlens_cpu.data_ptr<int32_t>()[i] != num_tokens * i) {
cu_seqlens_cpu = Tensor{};
break;
}
}
}
}
if (!cu_seqlens_cpu.valid()) {
cu_seqlens_cpu = Tensor::allocate({batch_size + 1}, Tensor::INT32, Device::cpu());
cu_seqlens_cpu.data_ptr<int32_t>()[0] = 0;
for (int i = 1; i <= batch_size; i++) {
cu_seqlens_cpu.data_ptr<int32_t>()[i] = cu_seqlens_cpu.data_ptr<int32_t>()[i - 1] + num_tokens;
}
}
Tensor cu_seqlens = cu_seqlens_cpu.copy(device);
Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});
Tensor q = reshaped.slice(1, 0, num_heads);
Tensor k = reshaped.slice(1, num_heads, num_heads * 2);
Tensor v = reshaped.slice(1, num_heads * 2, num_heads * 3);
spdlog::debug("q,k,v={}", q.shape.str());
Tensor raw_attn_output = mha_fwd_block(
q, k, v,
cu_seqlens, cu_seqlens,
POOL_SIZE, POOL_SIZE,
headmask_type,
{},
blockmask,
num_tokens,
num_tokens,
0.0f,
pow(q.shape[-1], (-0.5)),
false, false, false, -1, -1
).front();
/**
Tensor raw_attn_output = mha_varlen_fwd(q, k, v,
cu_seqlens,
cu_seqlens,
concat.shape[1],
concat.shape[1],
0.0f,
pow(q.shape[-1], (-0.5)),
false,
true,
-1, -1,
false
).front();
Tensor raw_attn_output = mha_fwd(q, k, v,
0.0f,
pow(q.shape[-1], (-0.5)),
false, -1, -1, false
).front();
Tensor raw_attn_output = mha_varlen_fwd(
q, k, v,
cu_seqlens, cu_seqlens,
num_tokens_img + num_tokens_context, num_tokens_img + num_tokens_context,
0.0f,
pow(q.shape[-1], (-0.5)),
false, false, -1, -1, false
).front();
**/
assert(raw_attn_output.shape[0] == batch_size * num_tokens);
assert(raw_attn_output.shape[1] == num_heads);
assert(raw_attn_output.shape[2] == dim_head);
return raw_attn_output;
}
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_head(attention_head_dim / num_attention_heads),
num_heads(num_attention_heads),
mlp_hidden_dim(dim * mlp_ratio),
norm(dim, dtype, device),
mlp_fc1(dim, mlp_hidden_dim, true, dtype, device),
mlp_fc2(mlp_hidden_dim, dim, true, dtype, device),
qkv_proj(dim, dim * 3, true, dtype, device),
norm_q(dim_head, 1e-6, false, dtype, device),
norm_k(dim_head, 1e-6, false, dtype, device),
attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
out_proj(dim, dim, true, dtype, device)
{
registerChildren
(norm, "norm")
(mlp_fc1, "mlp_fc1")
(mlp_fc2, "mlp_fc2")
(qkv_proj, "qkv_proj")
(norm_q, "norm_q")
(norm_k, "norm_k")
(out_proj, "out_proj")
;
}
Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {
nvtxRangePushA("FluxSingleTransformerBlock");
const int batch_size = hidden_states.shape[0];
const int num_tokens = hidden_states.shape[1];
auto &&[norm_hidden_states, gate] = this->norm.forward(hidden_states, temb);
debug("norm_hidden_states", norm_hidden_states);
debug("gate", gate);
Tensor residual = hidden_states;
Tensor qkv = Tensor::allocate({batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
// qkv_proj.forward(norm_hidden_states, qkv, {});
// debug("qkv_raw", qkv);
debug("rotary_emb", rotary_emb);
qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
debug("qkv", qkv);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
Tensor attn_output = attn.forward(qkv, {}, 0);
attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
debug("raw_attn_output", attn_output);
attn_output = forward_fc(out_proj, attn_output);
debug("attn_output", attn_output);
Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
debug("ff_output", ff_output);
hidden_states = add(attn_output, ff_output);
debug("attn_ff_output", hidden_states);
mul_add(hidden_states, gate, residual);
nvtxRangePop();
return hidden_states;
}
JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_head(attention_head_dim / num_attention_heads),
num_heads(num_attention_heads),
context_pre_only(context_pre_only),
norm1(dim, false, dtype, device),
norm1_context(dim, context_pre_only, dtype, device),
qkv_proj(dim, dim * 3, true, dtype, device),
qkv_proj_context(dim, dim * 3, true, dtype, device),
norm_q(dim_head, 1e-6, false, dtype, device),
norm_k(dim_head, 1e-6, false, dtype, device),
norm_added_q(dim_head, 1e-6, false, dtype, device),
norm_added_k(dim_head, 1e-6, false, dtype, device),
attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
out_proj(dim, dim, true, dtype, device),
out_proj_context(dim, dim, true, dtype, device),
norm2(dim, 1e-6, false, dtype, device),
norm2_context(dim, 1e-6, false, dtype, device),
mlp_fc1(dim, dim * 4, true, dtype, device),
mlp_fc2(dim * 4, dim, true, dtype, device),
mlp_context_fc1(dim, dim * 4, true, dtype, device),
mlp_context_fc2(dim * 4, dim, true, dtype, device)
{
registerChildren
(norm1, "norm1")
(norm1_context, "norm1_context")
(qkv_proj, "qkv_proj")
(qkv_proj_context, "qkv_proj_context")
(norm_q, "norm_q")
(norm_k, "norm_k")
(norm_added_q, "norm_added_q")
(norm_added_k, "norm_added_k")
(out_proj, "out_proj")
(out_proj_context, "out_proj_context")
(norm2, "norm2")
(norm2_context, "norm2_context")
(mlp_fc1, "mlp_fc1")
(mlp_fc2, "mlp_fc2")
(mlp_context_fc1, "mlp_context_fc1")
(mlp_context_fc2, "mlp_context_fc2")
;
}
// hidden_states: [Batch, Width * Height, dim]
// encoder_hidden_states: [Batch, Token, dim]
std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio) {
int batch_size = hidden_states.shape[0];
assert(encoder_hidden_states.shape[0] == batch_size);
nvtxRangePushA("JointTransformerBlock");
nvtxRangePushA("AdaNorm");
int num_tokens_img = hidden_states.shape[1];
int num_tokens_context = encoder_hidden_states.shape[1];
assert(hidden_states.shape[2] == dim);
assert(encoder_hidden_states.shape[2] == dim);
spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}", hidden_states.shape.str(), encoder_hidden_states.shape.str(), temb.shape.str());
spdlog::debug("batch_size={} num_tokens_img={} num_tokens_context={}", batch_size, num_tokens_img, num_tokens_context);
auto norm1_output = norm1.forward(hidden_states, temb);
auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);
#if 0
norm1_output.x = hidden_states;
norm1_context_output.x = encoder_hidden_states;
#endif
debug("norm_hidden_states", norm1_output.x);
debug("norm_encoder_hidden_states", norm1_context_output.x);
constexpr int POOL_SIZE = Attention::POOL_SIZE;
nvtxRangePop();
auto stream = getCurrentCUDAStream();
Tensor concat;
Tensor pool;
{
nvtxRangePushA("qkv_proj");
const bool blockSparse = sparsityRatio > 0;
const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_context / POOL_SIZE;
concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_context, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
pool = blockSparse
? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
: Tensor{};
for (int i = 0; i < batch_size; i++) {
// img first
Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_context);
Tensor pool_qkv = pool.valid()
? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE)
: Tensor{};
Tensor pool_qkv_context = pool.valid()
? pool.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_context / POOL_SIZE)
: Tensor{};
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
// debug("qkv_raw", qkv);
debug("rotary_emb", rotary_emb);
qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
debug("qkv", qkv);
// qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
// debug("qkv_context_raw", qkv_context);
debug("rotary_emb_context", rotary_emb_context);
qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
debug("qkv_context", qkv_context);
}
nvtxRangePop();
}
spdlog::debug("concat={}", concat.shape.str());
debug("concat", concat);
assert(concat.shape[2] == num_heads * dim_head * 3);
nvtxRangePushA("Attention");
Tensor raw_attn_output = attn.forward(concat, pool, sparsityRatio);
nvtxRangePop();
spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_context, num_heads, dim_head});
debug("raw_attn_output", raw_attn_output);
{
nvtxRangePushA("o_proj");
auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;
// raw_attn_output: [batch_size, num_tokens_img + num_tokens_context, num_heads * dim_head]
Tensor raw_attn_output_split;
if (batch_size == 1) {
raw_attn_output_split = raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
} else {
raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
checkCUDA(cudaMemcpy2DAsync(
raw_attn_output_split.data_ptr(),
num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
raw_attn_output.data_ptr(),
(num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
batch_size,
cudaMemcpyDeviceToDevice,
stream));
}
spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
debug("img.raw_attn_output_split", raw_attn_output_split);
Tensor attn_output = forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
debug("img.attn_output", attn_output);
#if 1
mul_add(attn_output, gate_msa, hidden_states);
hidden_states = std::move(attn_output);
nvtxRangePop();
nvtxRangePushA("MLP");
spdlog::debug("attn_output={}", hidden_states.shape.str());
Tensor norm_hidden_states = norm2.forward(hidden_states);
debug("scale_mlp", scale_mlp);
debug("shift_mlp", shift_mlp);
mul_add(norm_hidden_states, scale_mlp, shift_mlp);
spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
Tensor norm_hidden_states = hidden_states;
#endif
// Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
debug("img.ff_input", norm_hidden_states);
Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
debug("img.ff_output", ff_output);
debug("gate_mlp", gate_mlp);
mul_add(ff_output, gate_mlp, hidden_states);
hidden_states = std::move(ff_output);
nvtxRangePop();
spdlog::debug("ff_output={}", hidden_states.shape.str());
}
if (context_pre_only) {
return { hidden_states, encoder_hidden_states };
}
{
nvtxRangePushA("o_proj_context");
auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;
Tensor raw_attn_output_split;
if (batch_size == 1) {
raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img, num_tokens_img + num_tokens_context).reshape({batch_size, num_tokens_context, num_heads * dim_head});
} else {
raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_context, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
checkCUDA(cudaMemcpy2DAsync(
raw_attn_output_split.data_ptr(),
num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
raw_attn_output.data_ptr<char>() + num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
(num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
batch_size,
cudaMemcpyDeviceToDevice,
stream));
}
spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
debug("context.raw_attn_output_split", raw_attn_output_split);
Tensor attn_output = forward_fc(out_proj_context, raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
debug("context.attn_output", attn_output);
#if 1
mul_add(attn_output, gate_msa, encoder_hidden_states);
encoder_hidden_states = std::move(attn_output);
nvtxRangePop();
nvtxRangePushA("MLP");
spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());
Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
debug("c_scale_mlp", scale_mlp);
debug("c_shift_mlp", shift_mlp);
mul_add(norm_hidden_states, scale_mlp, shift_mlp);
spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
auto norm_hidden_states = encoder_hidden_states;
#endif
// Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
// Tensor ff_output = mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
debug("context.ff_input", norm_hidden_states);
Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
debug("context.ff_output", ff_output);
debug("c_gate_mlp", gate_mlp);
mul_add(ff_output, gate_mlp, encoder_hidden_states);
encoder_hidden_states = std::move(ff_output);
nvtxRangePop();
spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
}
nvtxRangePop();
return { hidden_states, encoder_hidden_states };
}
FluxModel::FluxModel(Tensor::ScalarType dtype, Device device) {
for (int i = 0; i < 19; i++) {
transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, dtype, device));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
}
for (int i = 0; i < 38; i++) {
single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, dtype, Device::cuda()));
registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
}
}
Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single) {
const int batch_size = hidden_states.shape[0];
const Tensor::ScalarType dtype = hidden_states.dtype();
const Device device = hidden_states.device();
const int txt_tokens = encoder_hidden_states.shape[1];
const int img_tokens = hidden_states.shape[1];
for (auto &&block : transformer_blocks) {
std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
}
// txt first, same as diffusers
Tensor concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
for (int i = 0; i < batch_size; i++) {
concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
}
hidden_states = concat;
encoder_hidden_states = {};
for (auto &&block : single_transformer_blocks) {
hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
}
return hidden_states;
}
#pragma once
#include "common.h"
#include "Tensor.h"
#include "Module.h"
#include "Linear.h"
#include "layernorm.h"
class AdaLayerNormZeroSingle : public Module {
public:
static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;
struct Output {
Tensor x;
Tensor gate_msa;
};
public:
AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device);
Output forward(Tensor x, Tensor emb);
public:
const int dim;
private:
GEMM linear;
LayerNorm norm;
};
class AdaLayerNormZero : public Module {
public:
static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;
struct Output {
Tensor x;
Tensor gate_msa;
Tensor shift_mlp;
Tensor scale_mlp;
Tensor gate_mlp;
};
public:
AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device);
Output forward(Tensor x, Tensor emb);
public:
const int dim;
const bool pre_only;
private:
GEMM linear;
LayerNorm norm;
};
class Attention {
public:
static constexpr int POOL_SIZE = 128;
Attention(int num_heads, int dim_head, Device device);
Tensor forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio);
public:
const int num_heads;
const int dim_head;
private:
Tensor cu_seqlens_cpu;
Tensor headmask_type;
};
class FluxSingleTransformerBlock : public Module {
public:
static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);
public:
const int dim;
const int dim_head;
const int num_heads;
const int mlp_hidden_dim;
private:
AdaLayerNormZeroSingle norm;
GEMM mlp_fc1;
GEMM mlp_fc2;
GEMM qkv_proj;
RMSNorm norm_q, norm_k;
Attention attn;
GEMM out_proj;
};
class JointTransformerBlock : public Module {
public:
static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device);
std::tuple<Tensor, Tensor> forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio);
public:
const int dim;
const int dim_head;
const int num_heads;
const bool context_pre_only;
private:
AdaLayerNormZero norm1;
AdaLayerNormZero norm1_context;
GEMM qkv_proj;
GEMM qkv_proj_context;
RMSNorm norm_q, norm_k;
RMSNorm norm_added_q, norm_added_k;
Attention attn;
GEMM out_proj;
GEMM out_proj_context;
LayerNorm norm2;
LayerNorm norm2_context;
GEMM mlp_fc1, mlp_fc2;
GEMM mlp_context_fc1, mlp_context_fc2;
};
class FluxModel : public Module {
public:
FluxModel(Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
public:
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
};
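// Note (not in the original header): FluxModel.cpp constructs this with 19
// JointTransformerBlocks followed by 38 FluxSingleTransformerBlocks, each using
// dim = 3072 and 24 attention heads, mirroring the FLUX.1 transformer layout.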
#include "Linear.h"
#include "kernels/gemm_w4a4.h"
#include "kernels/gemm_f16.h"
#include "kernels/misc_kernels.h"
#include "kernels/awq/gemv_awq.h"
GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f)
{
this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device);
this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
this->wzeros = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
// !!! lora layout is different from w4a4 !!!
this->lora_down = Tensor::allocate({lora_rank, in_features}, dtype, device, true);
this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true);
registerParams
(qweight, "qweight")
(wscales, "wscales")
(wzeros, "wzeros")
(bias, "bias")
(lora_down, "lora_down", ParamFlags::Optional)
(lora_up, "lora_up", ParamFlags::Optional)
;
}
void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = src.copy(this->qweight.device());
if (key == "lora_down") {
const int new_rank = dst.shape[0];
this->lora_rank = new_rank;
}
} else {
dst.copy_(src);
}
} else {
Module::loadParam(key, dst, src);
}
}
Tensor GEMV_AWQ::forward(Tensor x) {
debug("x", x);
const int M = (int)x.numel() / x.shape[-1];
Tensor out = gemv_awq(x, this->qweight, this->wscales, this->wzeros, M, out_features, in_features, group_size);
if (bias.valid()) {
// TODO: batch
assert(out.numel() == bias.numel());
out = add(out, bias.view(out.shape.dataExtent));
}
debug("out_before_lora", out);
if (this->lora_rank > 0) {
Tensor lora_act = gemm_f16(x, this->lora_down, {}, 1.0f, 0.0f);
debug("lora_act", lora_act);
Tensor lora_out = gemm_f16(lora_act, this->lora_up, {}, this->lora_scale, 0.0f);
debug("lora_out", lora_out);
out = add(out, lora_out);
}
debug("out", out);
return out;
}
#define NO_LORA_FUSION 0
GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) :
in_features(in_features), out_features(out_features), lora_rank(0), dtype(dtype)
{
this->qweight = Tensor::allocate({out_features, in_features / 2}, Tensor::INT8, device, true);
this->wscales = Tensor::allocate({in_features / 64, out_features}, dtype, device, true);
this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
this->lora_down = Tensor::allocate({in_features, lora_rank}, dtype, device, true);
this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true);
// TODO: smooth factor in FC1+FC2 fusion
// TODO: smooth factor in non-Lora fusion
this->smooth = Tensor::allocate({in_features}, dtype, device, true);
registerParams
(qweight, "qweight")
(wscales, "wscales")
(this->bias, "bias")
(lora_down, "lora_down", ParamFlags::Optional)
(lora_up, "lora_up", ParamFlags::Optional)
(smooth, "smooth")
;
#if NO_LORA_FUSION
checkCUBLAS(cublasCreate(&handle));
#endif
}
void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = src.copy(this->qweight.device());
this->lora_rank = dst.shape[1];
this->lora_scales.resize(ceilDiv(this->lora_rank, 16), 1.0f);
} else {
dst.copy_(src);
}
} else {
Module::loadParam(key, dst, src);
}
}
std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
return forward_quant(quantize(x), fuse, nextGEMM);
}
void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb) {
QuantizedActivation qact = quantize(x);
#if !NO_LORA_FUSION
#if 0
Tensor dummy = Tensor::empty_like(qact.lora_act);
dummy.zero_();
gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, dummy, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned);
debug("gemm.nolora.out", out);
#endif
gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned, this->lora_scales);
debug("gemm.out", out);
#else
const int M = (int)qact.act.numel() / qact.act.shape[-1];
gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, {}, {}, {}, {});
nvtxRangePushA("LoraUp");
static const half one = 1.0;
static const half zero = 0.0;
// lora_up: [M, R] * [OC, R] => [M, OC]
// cublas view: [OC, R] * [M, R]^T
checkCUBLAS(cublasHgemm(
handle,
CUBLAS_OP_T, CUBLAS_OP_N,
this->out_features, M, this->lora_rank,
&one,
this->lora_up.data_ptr<half>(),
this->lora_rank,
qact.lora_act.data_ptr<half>(),
this->lora_rank,
&one,
out.data_ptr<half>(),
this->out_features));
nvtxRangePop();
#endif
}
std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
Tensor out;
QuantizedActivation qout;
Tensor next_lora;
Tensor next_smooth;
const int M = (int)qact.act.numel() / qact.act.shape[-1];
if (fuse == FuseOptions::EMPTY) {
auto shape = TensorShape(qact.act.shape.dataExtent);
shape[-1] = out_features;
out = Tensor::allocate(shape, dtype, qweight.device());
} else {
auto shape = TensorShape(qact.act.shape.dataExtent);
shape[-1] = out_features / 2;
qout.act = Tensor::allocate(shape, Tensor::INT8, qweight.device());
qout.ascales = Tensor::allocate({out_features / 64, M}, dtype, qweight.device());
qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qout.is_unsigned = true;
next_lora = nextGEMM->lora_down;
next_smooth = nextGEMM->smooth;
}
#if !NO_LORA_FUSION
#if 0
Tensor dummy = Tensor::empty_like(qact.lora_act);
dummy.zero_();
gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, dummy, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, qact.is_unsigned);
if (fuse == FuseOptions::EMPTY) {
debug("gemm.nolora.out", out);
} else {
debug("gemm.nolora.qout", qout.act);
debug("gemm.nolora.oscales", qout.ascales);
debug("gemm.nolora.lora_act_out", qout.lora_act);
}
#endif
gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, qact.is_unsigned, this->lora_scales);
if (fuse == FuseOptions::EMPTY) {
debug("gemm.out", out);
} else {
debug("gemm.qout", qout.act);
debug("gemm.oscales", qout.ascales);
debug("gemm.lora_act_out", qout.lora_act);
}
#else
if (!out.valid()) {
auto shape = TensorShape(qact.act.shape.dataExtent);
shape[-1] = out_features;
out = Tensor::allocate(shape, Tensor::FP16, qweight.device());
}
gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, {}, {}, {}, {});
nvtxRangePushA("LoraUp");
static const half one = 1.0;
static const half zero = 0.0;
// lora_up: [M, R] * [OC, R]^T => [M, OC]
// cublas view: [R, OC]^T * [R, M] => [OC, M]
// lora_up layout wrong?
checkCUBLAS(cublasHgemm(
handle,
CUBLAS_OP_T, CUBLAS_OP_N,
this->out_features, M, this->lora_rank,
&one,
this->lora_up.data_ptr<half>(),
this->lora_rank,
qact.lora_act.data_ptr<half>(),
this->lora_rank,
&one,
out.data_ptr<half>(),
this->out_features));
nvtxRangePop();
if (fuse == FuseOptions::GELU_QUANT) {
nvtxRangePushA("LoraDown");
// IC is for next lora (OC of this layer)
// lora_down: [M, IC] * [IC, R] => [M, R]
// cublas view: [R, IC] * [IC, M] => [R, M]
checkCUBLAS(cublasHgemm(
handle,
CUBLAS_OP_N, CUBLAS_OP_N,
this->lora_rank, M, this->out_features,
&one,
next_lora.data_ptr<half>(),
this->lora_rank,
out.data_ptr<half>(),
this->out_features,
&zero,
qout.lora_act.data_ptr<half>(),
this->lora_rank));
out = {};
nvtxRangePop();
}
#endif
if (out.valid()) {
return out;
}
return qout;
}
GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x) {
const int M = x.numel() / x.shape[-1];
auto shape = TensorShape(x.shape.dataExtent);
shape[-1] = in_features / 2;
QuantizedActivation qact;
qact.act = Tensor::allocate(shape, Tensor::INT8, qweight.device());
qact.ascales = Tensor::allocate({in_features / 64, M}, dtype, qweight.device());
qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qact.is_unsigned = false;
#if !NO_LORA_FUSION
debug("quantize.x", x);
debug("quantize.smooth", this->smooth);
quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth);
debug("quantize.qact", qact.act);
debug("quantize.ascales", qact.ascales);
debug("quantize.lora_act", qact.lora_act);
#else
static const half one = 1.0;
static const half zero = 0.0;
nvtxRangePushA("LoraDown");
// lora_down: [M, IC] * [IC, R] => [M, R]
// cublas view: [R, IC] * [IC, M]
checkCUBLAS(cublasHgemm(
handle,
CUBLAS_OP_N, CUBLAS_OP_N,
this->lora_rank, M, this->in_features,
&one,
lora_down.data_ptr<half>(),
this->lora_rank,
x.data_ptr<half>(),
this->in_features,
&zero,
qact.lora_act.data_ptr<half>(),
this->lora_rank));
nvtxRangePop();
quantize_w4a4_act(x, qact.act, qact.ascales);
#endif
return qact;
}
#pragma once
#include "common.h"
#include "Tensor.h"
#include "Module.h"
class GEMV_AWQ : public Module {
public:
GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x);
protected:
virtual void loadParam(std::string key, Tensor &dst, Tensor src) override;
public:
const int in_features;
const int out_features;
const int group_size;
int lora_rank;
float lora_scale;
public:
Tensor qweight;
Tensor wscales;
Tensor wzeros;
Tensor bias;
Tensor lora_down;
Tensor lora_up;
// std::shared_ptr<CUBLASWrapper> cublas;
};
class GEMM_W4A4 : public Module {
public:
enum class FuseOptions {
EMPTY = 0,
GELU_QUANT,
};
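// Comment added for clarity (not in the original header): with FuseOptions::EMPTY,
// forward()/forward_quant() return a plain Tensor. With FuseOptions::GELU_QUANT they
// return a QuantizedActivation instead: the GELU and the re-quantization for the next
// layer appear to be fused into the GEMM epilogue, using nextGEMM's lora_down and smooth
// factors (see forward_mlp in FluxModel.cpp for the FC1 -> FC2 usage).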
struct QuantizedActivation {
Tensor act;
Tensor ascales;
Tensor lora_act;
bool is_unsigned = false;
};
public:
GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse = FuseOptions::EMPTY, GEMM_W4A4 *nextGEMM = nullptr);
void forward(Tensor x, Tensor out, Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {});
std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse = FuseOptions::EMPTY, GEMM_W4A4 *nextGEMM = nullptr);
public:
QuantizedActivation quantize(Tensor x);
public:
const int in_features;
const int out_features;
int lora_rank;
std::vector<float> lora_scales; // every 16 ranks share a scale
const Tensor::ScalarType dtype;
protected:
virtual void loadParam(std::string key, Tensor &dst, Tensor src) override;
public:
Tensor qweight;
Tensor wscales;
Tensor bias;
Tensor lora_down;
Tensor lora_up;
Tensor smooth;
cublasHandle_t handle;
};
// TODO
class GEMM_W8A8;
#pragma once
#include "common.h"
#include "Tensor.h"
#include "debug.h"
class Module {
protected:
enum class ParamFlags : int {
None = 0,
Optional = 1,
};
struct Param {
Tensor *tensor;
ParamFlags flags;
};
friend inline ParamFlags operator|(ParamFlags lhs, ParamFlags rhs) {
return static_cast<ParamFlags>(static_cast<int>(lhs) | static_cast<int>(rhs));
}
friend inline ParamFlags operator&(ParamFlags lhs, ParamFlags rhs) {
return static_cast<ParamFlags>(static_cast<int>(lhs) & static_cast<int>(rhs));
}
public:
std::string getFullName() const {
if (!parent) {
return name;
}
std::string fullName = parent->getFullName();
if (fullName.empty()) {
return name;
} else {
return fullName + "." + name;
}
}
void traverse(std::function<void(Module *)> func) {
func(this);
for (Module *c : this->children) {
c->traverse(func);
}
}
virtual void loadParams(TensorsProvider &provider, bool partial = false) {
for (Module *c : children) {
c->loadParams(provider, partial);
}
std::string fullName = getFullName();
std::string prefix = fullName.empty() ? "" : fullName + ".";
for (auto &&[key, param] : params) {
Tensor src = provider.getTensor(prefix + key);
if (!src.valid()) {
if (partial || int(param.flags & ParamFlags::Optional)) {
continue;
}
throw std::runtime_error(spdlog::fmt_lib::format("Tensor {} not found", prefix + key));
}
this->loadParam(key, *param.tensor, src);
// tensor->copy_(src);
}
}
void setName(std::string name) {
assert(!parent);
this->name = std::move(name);
}
protected:
virtual void loadParam(std::string key, Tensor &dst, Tensor src) {
dst.copy_(src);
}
struct ChildrenRegisterHelper {
ChildrenRegisterHelper(Module &self) : self(self) {}
Module &self;
ChildrenRegisterHelper operator()(Module &module, std::string name) {
return self.registerChildren(module, name);
}
};
ChildrenRegisterHelper registerChildren(Module &module, std::string name) {
module.parent = this;
module.name = name;
children.push_back(&module);
return ChildrenRegisterHelper(*this);
}
struct ParamsRegisterHelper {
ParamsRegisterHelper(Module &self) : self(self) {}
Module &self;
ParamsRegisterHelper operator()(Tensor &param, std::string name, ParamFlags flags = ParamFlags::None) {
return self.registerParams(param, name, flags);
}
};
ParamsRegisterHelper registerParams(Tensor &param, std::string name, ParamFlags flags = ParamFlags::None) {
if (param.valid()) {
params[name].tensor = &param;
params[name].flags = flags;
}
return ParamsRegisterHelper(*this);
}
void debug(std::string name, Tensor tensor) {
if (DebugContext::ctxs.empty()) {
return;
}
std::string prefix = getFullName();
if (!prefix.empty()) {
prefix += ".";
}
tensor = tensor.copy(Device::cpu());
for (auto &&ctx : DebugContext::ctxs) {
ctx->tensors[prefix + name] = tensor;
}
}
public:
Module *parent = nullptr;
std::string name = "";
std::vector<Module *> children;
std::map<std::string, Param> params;
};
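// Usage sketch (not in the original header): submodules and parameters are registered
// with the chained helpers above so that loadParams() can resolve dotted names such as
// "transformer_blocks.0.qkv_proj.qweight" from a TensorsProvider. Hypothetical example:
//
//   struct MyBlock : public Module {
//       MyBlock(Tensor::ScalarType dtype, Device device)
//           : fc1(64, 256, true, dtype, device), fc2(256, 64, true, dtype, device)
//       {
//           registerChildren
//               (fc1, "fc1")
//               (fc2, "fc2")
//           ;
//       }
//       GEMM_W4A4 fc1, fc2;   // leaf modules register their own Tensors via registerParams
//   };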
#include "Serialization.h"
#include <nlohmann/json.hpp>
#include <mio/mmap.hpp>
// #include <sys/mman.h>
using json = nlohmann::json;
using spdlog::fmt_lib::format;
class SafeTensors::mmap_file : public mio::mmap_source {
public:
mmap_file(std::string_view filename) : mio::mmap_source(filename, 0, mio::map_entire_file) {}
};
SafeTensors::SafeTensors(std::string_view filename) {
std::error_code ec;
this->mapped = std::make_unique<mmap_file>(filename);
if (ec) {
throw std::system_error(ec);
}
// char *ptr = (char *)malloc(1024);
// checkCUDA(cudaHostRegister(ptr, 1024, cudaHostRegisterDefault));
if (cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable | cudaHostRegisterReadOnly) != cudaSuccess) {
spdlog::warn("Unable to pin memory: {}", cudaGetErrorString(cudaGetLastError()));
// mlock(const_cast<char *>(this->mapped->data()), this->mapped->size());
}
parseHeader();
}
SafeTensors::~SafeTensors() {
checkCUDA(cudaHostUnregister(const_cast<char *>(this->mapped->data())));
}
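// parseHeader() below walks the standard safetensors layout: the first 8 bytes are a
// little-endian uint64 header size N, followed by N bytes of JSON mapping tensor names
// to {"dtype", "shape", "data_offsets"}, with data_offsets measured from the end of the
// header into the raw tensor bytes that make up the rest of the file.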
void SafeTensors::parseHeader() {
static const std::unordered_map<std::string, Tensor::ScalarType> mapDType = {
{ "BF16", Tensor::BF16 },
{ "F16", Tensor::FP16 },
{ "F32", Tensor::FP32 },
{ "I8", Tensor::INT8 },
{ "I32", Tensor::INT32 },
{ "I64", Tensor::INT64 },
};
auto check = [](bool cond, std::source_location location = std::source_location::current()) {
if (!cond) {
throw std::runtime_error(format("Safetensors check failed at {}:{}", location.file_name(), location.line()));
}
};
check(this->mapped->size() > 8);
uint64_t sizeHeader = *reinterpret_cast<const uint64_t *>(this->mapped->data());
check(this->mapped->size() - 8 >= sizeHeader);
json header = json::parse(this->mapped->begin() + 8, this->mapped->begin() + 8 + sizeHeader);
const uint64_t offsetMax = this->mapped->size() - sizeHeader - 8;
std::set<size_t> offsets;
for (auto &&[key, info] : header.items()) {
if (key == "__metadata__") {
continue;
}
auto dtype = mapDType.at(info["dtype"].get<std::string>());
auto shape = info["shape"].get<std::vector<int>>();
auto data_offsets = info["data_offsets"].get<std::vector<uint64_t>>();
check(data_offsets.size() == 2);
check(data_offsets[0] <= data_offsets[1]);
check(data_offsets[0] < offsetMax);
check(data_offsets[1] <= offsetMax);
for (auto &&dim : shape) {
check(dim >= 0);
}
TensorInfo tinfo;
tinfo.type = dtype;
tinfo.shape = TensorShape(shape);
tinfo.length = data_offsets[1] - data_offsets[0];
tinfo.offset = 8 + sizeHeader + data_offsets[0];
// TODO: check range overlap
check(!offsets.contains(tinfo.offset));
offsets.insert(tinfo.offset);
check(tinfo.shape.size() * Tensor::scalarSize.at(tinfo.type) <= tinfo.length);
tensors[key] = tinfo;
}
}
Tensor SafeTensors::getTensor(const std::string &key) {
if (!tensors.contains(key)) {
return Tensor{};
}
TensorInfo &info = tensors.at(key);
std::shared_ptr<BufferMMap> buffer = info.buffer.lock();
if (!buffer) {
buffer = std::make_shared<BufferMMap>(const_cast<char *>(this->mapped->data() + info.offset), info.length, shared_from_this());
info.buffer = buffer;
}
Tensor result;
result.shape = info.shape;
result.scalarType = info.type;
result.buffer = buffer;
return result;
}
#pragma once
#include "common.h"
#include "Tensor.h"
class BufferMMap : public Buffer {
public:
BufferMMap(void *ptr, size_t size, std::shared_ptr<void> parent) : parent(parent) {
this->size = size;
this->device.type = Device::CPU;
this->ptr = ptr;
// auto ret = cudaHostRegister(ptr, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
// if (ret == cudaSuccess) {
// this->registered = true;
// } else {
// log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size, cudaGetErrorString(cudaGetLastError())));
// this->registered = false;
// }
}
virtual ~BufferMMap() {
// if (registered) {
// checkCUDA(cudaHostUnregister(ptr));
// }
}
public:
std::shared_ptr<void> parent;
// bool registered;
};
class SafeTensors : public TensorsProvider, public std::enable_shared_from_this<SafeTensors> {
public:
SafeTensors(std::string_view filename);
~SafeTensors();
virtual bool contains(const std::string &key) const override {
return tensors.contains(key);
}
virtual Tensor getTensor(const std::string &key) override;
private:
void parseHeader();
private:
class mmap_file;
struct TensorInfo {
TensorShape shape;
Tensor::ScalarType type;
size_t offset;
size_t length;
std::weak_ptr<BufferMMap> buffer;
};
std::map<std::string, TensorInfo> tensors;
std::unique_ptr<mmap_file> mapped;
};
#pragma once
#include "common.h"
struct Device {
enum Type {
INVALID_DEVICE_TYPE = 0,
CPU, CUDA
};
Type type = INVALID_DEVICE_TYPE;
int idx = 0;
static constexpr Device cpu(int idx = 0) {
return Device{CPU, idx};
}
static constexpr Device cuda(int idx = 0) {
return Device{CUDA, idx};
}
};
// template<bool readonly>
class Buffer : public std::enable_shared_from_this<Buffer> {
public:
virtual ~Buffer() {}
void *getPtr() { return ptr; }
template<typename T>
T *getPtr() { return reinterpret_cast<T *>(ptr); }
size_t getSize() { return size; }
Device getDevice() { return device; }
protected:
template <typename Derived>
std::shared_ptr<Derived> shared_from_base() {
return std::static_pointer_cast<Derived>(shared_from_this());
}
protected:
// std::conditional_t<readonly, const void *, void *> ptr;
void *ptr;
size_t size;
Device device;
};
// using Buffer = BufferTemplate<false>;
// using BufferReadonly = BufferTemplate<true>;
class BufferMalloc : public Buffer {
public:
BufferMalloc(size_t size) {
this->size = size;
this->device.type = Device::CPU;
this->ptr = malloc(size);
}
virtual ~BufferMalloc() {
free(this->ptr);
}
};
class BufferHost : public Buffer {
public:
BufferHost(size_t size) {
this->size = size;
this->device.type = Device::CPU;
checkCUDA(cudaHostAlloc(&this->ptr, size, cudaHostAllocPortable));
}
virtual ~BufferHost() {
checkCUDA(cudaFreeHost(this->ptr));
}
};
class BufferCUDA : public Buffer {
public:
BufferCUDA(size_t size) {
this->size = size;
this->device.type = Device::CUDA;
checkCUDA(cudaGetDevice(&this->device.idx));
if (size == 0) {
this->ptr = nullptr;
return;
}
checkCUDA(cudaMallocAsync(&this->ptr, size, 0)); // use default stream to sync with all other streams
}
virtual ~BufferCUDA() {
if (this->size == 0) {
assert(!this->ptr);
return;
}
checkCUDA(cudaFreeAsync(this->ptr, 0));
}
};
class BufferCUDASync : public Buffer {
public:
BufferCUDASync(size_t size) {
this->size = size;
this->device.type = Device::CUDA;
checkCUDA(cudaGetDevice(&this->device.idx));
checkCUDA(cudaMalloc(&this->ptr, size));
}
virtual ~BufferCUDASync() {
checkCUDA(cudaFree(this->ptr));
}
};
class BufferView : public Buffer {
public:
BufferView(std::shared_ptr<Buffer> reference, size_t offset, size_t size) : reference(reference) {
assert(offset + size <= reference->getSize());
this->ptr = (void *)((std::uint8_t *)reference->getPtr() + offset);
this->size = size;
this->device = reference->getDevice();
}
private:
std::shared_ptr<Buffer> reference;
};
struct TensorShape {
std::vector<int> dataExtent;
std::vector<int> dataStride;
int64_t offset = 0;
TensorShape() {}
TensorShape(std::vector<int> shape) : dataExtent(std::move(shape)) {}
TensorShape(std::initializer_list<int> dims) : dataExtent(dims) {}
bool is_contiguous() const {
if (dataStride.empty()) {
return true;
}
if (size() == 0) {
return true;
}
int64_t prod = 1;
for (int i = dataExtent.size() - 1; i >= 0; i--) {
if (dataExtent[i] > 1 && dataStride[i] != prod) {
return false;
}
prod *= dataExtent[i];
}
return true;
}
int ndims() const {
return dataExtent.size();
}
const int &operator[](int idx) const {
if (idx < 0) {
return dataExtent.at(dataExtent.size() + idx);
} else {
return dataExtent.at(idx);
}
}
int &operator[](int idx) {
return const_cast<int &>(const_cast<const TensorShape *>(this)->operator[](idx));
}
size_t stride(int idx) const {
if (!dataStride.empty()) {
if (idx < 0) {
return dataStride.at(dataStride.size() + idx);
} else {
return dataStride.at(idx);
}
}
if (idx < 0) {
idx = dataExtent.size() + idx;
}
assert(idx >= 0 && (size_t)idx < dataExtent.size());
size_t result = 1;
for (size_t i = idx + 1; i < dataExtent.size(); i++) {
assert(dataExtent[i] >= 0);
result *= dataExtent[i];
}
return result;
}
size_t size() const {
if (dataExtent.empty()) {
return 0;
}
size_t result = 1;
for (int dim : dataExtent) {
assert(dim >= 0);
result *= dim;
}
return result;
}
std::string str() const {
if (dataExtent.empty()) {
return "[]";
}
std::stringstream ss;
ss << "[" << dataExtent[0];
for (size_t i = 1; i < dataExtent.size(); i++) {
ss << ", " << dataExtent[i];
}
ss << "]";
return ss.str();
}
};
class Tensor {
public:
enum ScalarType {
INVALID_SCALAR_TYPE,
INT8, INT32, INT64,
FP16, FP32, BF16
};
struct TensorOptions {
Device device_;
ScalarType dtype_;
Device device() const { return device_; }
ScalarType dtype() const { return dtype_; }
TensorOptions device(Device dev) const {
TensorOptions result(*this);
result.device_ = dev;
return result;
}
TensorOptions dtype(ScalarType type) const {
TensorOptions result(*this);
result.dtype_ = type;
return result;
}
};
static const std::map<ScalarType, size_t> scalarSize;
public:
TensorShape shape;
ScalarType scalarType;
std::shared_ptr<Buffer> buffer;
public:
bool valid() const { return shape.dataExtent.size() > 0; }
int size(int dim) const { return shape[dim]; }
bool is_contiguous() const { return shape.is_contiguous(); }
std::vector<int> sizes() const { return shape.dataExtent; }
bool is_cuda() const { return device().type == Device::CUDA; }
TensorOptions options() const { return TensorOptions{device(), dtype()}; }
int get_device() const { return device().idx; }
template<typename T>
T *data_ptr() { return reinterpret_cast<T*>(data_ptr()); }
template<typename T>
const T *data_ptr() const { return reinterpret_cast<const T*>(data_ptr()); }
const void *data_ptr() const { return buffer->getPtr<char>() + shape.offset * scalar_size(); }
void *data_ptr() { return buffer->getPtr<char>() + shape.offset * scalar_size(); }
Device device() const { return buffer->getDevice(); }
ScalarType scalar_type() const { return scalarType; }
ScalarType dtype() const { return scalar_type(); }
size_t stride(int dim) const { return shape.stride(dim); }
size_t numel() const { return shape.size(); }
size_t ndims() const { return shape.ndims(); }
size_t dim() const { return ndims(); }
size_t scalar_size() const { return scalarSize.at(scalarType); }
Tensor operator[](int idx) const {
assert(ndims() > 1);
Tensor result;
result.shape = std::vector<int>(this->shape.dataExtent.begin() + 1, this->shape.dataExtent.end());
size_t size = stride(0) * scalar_size();
result.buffer = std::make_shared<BufferView>(this->buffer, idx * size, size);
result.scalarType = this->scalarType;
return result;
}
template<typename T>
const T & at(const std::vector<int> &idx) const {
assert(ndims() == idx.size());
int64_t offset = 0;
for (size_t i = 0; i < ndims(); i++) {
offset += idx.at(i) * stride(i);
}
assert(offset >= 0 && offset < numel());
return this->data_ptr<T>()[offset];
}
template<typename T>
T & at(const std::vector<int> &idx) {
return const_cast<T &>(const_cast<const Tensor *>(this)->at<T>(idx));
}
Tensor slice(int dim, int from, int to) const {
assert(from <= to);
Tensor result;
result.buffer = this->buffer;
result.scalarType = this->scalarType;
result.shape = TensorShape(this->shape.dataExtent);
result.shape[dim] = to - from;
result.shape.dataStride.resize(result.shape.ndims());
for (int i = 0; i < result.shape.ndims(); i++) {
result.shape.dataStride[i] = this->shape.stride(i);
}
result.shape.offset = this->shape.offset + this->shape.stride(dim) * from;
return result;
}
Tensor transpose(int dim1, int dim2) const {
Tensor result;
result.buffer = this->buffer;
result.scalarType = this->scalarType;
result.shape = TensorShape(this->shape.dataExtent);
result.shape.dataStride.resize(result.shape.ndims());
for (int i = 0; i < result.shape.ndims(); i++) {
result.shape.dataStride[i] = this->shape.stride(i);
}
result.shape.offset = this->shape.offset;
std::swap(result.shape.dataExtent[dim1], result.shape.dataExtent[dim2]);
std::swap(result.shape.dataStride[dim1], result.shape.dataStride[dim2]);
return result;
}
Tensor view(TensorShape shape) const {
assert(shape.size() == this->shape.size());
assert(this->is_contiguous());
Tensor result;
result.buffer = this->buffer;
result.scalarType = this->scalarType;
result.shape = shape;
result.shape.offset = this->shape.offset;
return result;
}
Tensor reshape(TensorShape shape) const {
return view(shape);
}
// // NOT IMPLEMENTED!!! DONT USE
// Tensor transpose(int a, int b) const {
// throw std::runtime_error("Not implemented");
// }
Tensor &zero_() {
assert(this->is_contiguous());
checkCUDA(cudaMemset(data_ptr(), 0, shape.size() * scalar_size())); // data_ptr() already applies shape.offset
return *this;
}
Tensor &copy_(Tensor other) {
assert(this->is_contiguous());
assert(other.is_contiguous());
assert(this->shape.dataExtent == other.shape.dataExtent);
assert(this->dtype() == other.dtype());
assert((shape.offset + shape.size()) * scalar_size() <= buffer->getSize());
assert((other.shape.offset + shape.size()) * scalar_size() <= other.buffer->getSize());
if (shape.size() == 0) {
return *this;
}
if (this->device().type == Device::CPU && other.device().type == Device::CPU) {
memcpy(
data_ptr<char>(),
other.data_ptr<char>(),
shape.size() * scalar_size()
);
return *this;
}
lockBuffer(this->buffer, getCurrentCUDAStream());
lockBuffer(other.buffer, getCurrentCUDAStream());
checkCUDA(cudaMemcpyAsync(
data_ptr<char>(),
other.data_ptr<char>(),
shape.size() * scalar_size(),
getCopyKind(this->device(), other.device()),
getCurrentCUDAStream()
));
return *this;
}
// NOT IMPLEMENTED!!! DONT USE
template<typename T>
Tensor &fill_(T val) {
throw std::runtime_error("Not implemented");
return *this;
}
// NOT IMPLEMENTED!!! DONT USE
Tensor index(std::vector<std::any> whatever) {
throw std::runtime_error("Not implemented");
}
public:
static Tensor allocate(TensorShape shape, ScalarType scalarType, Device device, bool fill = false) {
Tensor result;
assert(shape.is_contiguous());
if (device.type == Device::CPU) {
result.buffer = std::make_shared<BufferMalloc>(shape.size() * scalarSize.at(scalarType));
} else if (device.type == Device::CUDA) {
// TODO: cross device allocate
result.buffer = std::make_shared<BufferCUDA>(shape.size() * scalarSize.at(scalarType));
} else {
assert(false);
}
result.scalarType = scalarType;
result.shape = shape;
if (fill) {
if (device.type == Device::CPU) {
memset(result.buffer->getPtr(), 0xCC, result.buffer->getSize());
} else if (device.type == Device::CUDA) {
checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
}
}
return result;
}
static Tensor empty(TensorShape shape, ScalarType scalarType, Device device) {
return allocate(shape, scalarType, device);
}
static Tensor empty_like(const Tensor &tensor) {
return empty(TensorShape(tensor.shape.dataExtent), tensor.scalarType, tensor.device());
}
static Tensor ones(TensorShape shape, ScalarType scalarType, Device device) {
Tensor result = allocate(shape, scalarType, device);
// FIXME: cudaMemsetAsync fills every byte with 1, which only yields the value 1 for INT8; other dtypes need a real fill kernel.
checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 1, result.buffer->getSize(), getCurrentCUDAStream()));
return result;
}
static Tensor allocate_view(TensorShape shape, ScalarType scalarType, std::shared_ptr<Buffer> buffer, size_t offset = 0) {
Tensor result;
result.buffer = std::make_shared<BufferView>(buffer, offset, shape.size() * scalarSize.at(scalarType));
result.scalarType = scalarType;
result.shape = shape;
return result;
}
public:
Tensor copy(Device device) const {
if (!buffer) {
return *this;
}
Tensor result = allocate(this->shape.dataExtent, this->scalarType, device);
result.copy_(*this);
// lockBuffer(this->buffer, getCurrentCUDAStream());
// lockBuffer(result.buffer, getCurrentCUDAStream());
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault, getCurrentCUDAStream()));
// if (this->device().type == Device::CPU && device.type == Device::CUDA) {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyHostToDevice, getCurrentCUDAStream()));
// } else if (this->device().type == Device::CUDA && device.type == Device::CPU) {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
// } else {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault, getCurrentCUDAStream()));
// }
return result;
}
// void copy_range(Tensor &dst, int dim, int lower_bound, int upper_bound) {
// if (upper_bound > shape[dim]) {
// upper_bound = shape[dim];
// }
// if (lower_bound >= upper_bound) {
// return;
// }
// auto shapeOut = this->shape;
// shapeOut[dim] = upper_bound - lower_bound;
// assert(dst.shape.data == shapeOut.data);
// checkCUDA(cudaMemcpy2DAsync(
// dst.
// ));
// }
private:
static cudaMemcpyKind getCopyKind(Device dst, Device src) {
if (src.type == Device::CPU && dst.type == Device::CUDA) {
return cudaMemcpyHostToDevice;
}
if (src.type == Device::CUDA && dst.type == Device::CPU) {
return cudaMemcpyDeviceToHost;
}
if (src.type == Device::CUDA && dst.type == Device::CUDA) {
return cudaMemcpyDeviceToDevice;
}
if (src.type == Device::CPU && dst.type == Device::CPU) {
return cudaMemcpyHostToHost;
}
return cudaMemcpyDefault;
}
static bool isAsyncBuffer(Buffer *buffer) {
return dynamic_cast<BufferCUDA *>(buffer);
}
static inline std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;
public:
// Before launching an async operation, lock the buffer so it cannot be freed before the GPU has finished using it.
static void lockBuffer(std::shared_ptr<Buffer> buffer, cudaStream_t stream) {
if (!isAsyncBuffer(buffer.get())) {
lockedBuffers[stream].insert(buffer);
}
}
// Buffers can be unlocked once the corresponding GPU work has been synchronized.
static void unlockBuffers() {
lockedBuffers.clear();
}
static void unlockBuffers(cudaStream_t stream) {
lockedBuffers[stream].clear();
}
static void synchronizeDevice() {
checkCUDA(cudaDeviceSynchronize());
unlockBuffers();
}
static void synchronizeStream(cudaStream_t stream) {
checkCUDA(cudaStreamSynchronize(stream));
unlockBuffers(stream);
}
};
inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
{INT8, 1},
{INT32, 4},
{INT64, 8},
{FP16, 2},
{FP32, 4},
{BF16, 2},
};
struct TensorsProvider {
virtual ~TensorsProvider() {}
virtual bool contains(const std::string &key) const = 0;
virtual Tensor getTensor(const std::string &key) = 0;
};
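The slicing and copy helpers above never move data eagerly: slice() and transpose() are views over the same buffer, while copy()/copy_() enqueue asynchronous transfers on the current stream and rely on lockBuffer()/synchronizeDevice() to keep non-stream-ordered buffers alive. A minimal host-side sketch of that flow, assuming Device values are obtained elsewhere and that TensorShape accepts a brace-initialized extent list (illustrative only, not part of the build):
#include "common.h"
#include "Tensor.h"
// gpu/cpu are passed in because Device construction lives elsewhere in this file set.
void tensor_api_sketch(Device gpu, Device cpu) {
    // allocate() requires a contiguous shape; the {4, 8} extent is arbitrary.
    Tensor a = Tensor::allocate(TensorShape({4, 8}), Tensor::FP16, gpu);
    a.zero_();                            // cudaMemset on the underlying buffer
    // Views: same buffer, only extents/strides/offset change, no data is copied.
    Tensor cols = a.slice(1, 0, 4);
    Tensor a_t  = a.transpose(0, 1);
    (void)cols; (void)a_t;
    // copy() allocates on the target device; copy_() issues cudaMemcpyAsync on the
    // current stream and locks non-stream-ordered buffers until synchronization.
    Tensor host_copy = a.copy(cpu);
    (void)host_copy;
    Tensor::synchronizeDevice();          // waits for the GPU and unlocks buffers
}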
#include "activation.h"
#include "kernels/activation_kernels.h"
Tensor Silu::forward(Tensor x) {
Tensor out = Tensor::empty_like(x);
silu(out, x);
return out;
}
Tensor GELU::forward(Tensor x) {
Tensor out = Tensor::empty_like(x);
gelu_new(out, x);
return out;
}
// Tensor SiluAndMul::forward(Tensor x) {
// int d = x.shape[-1] / 2;
// auto output_shape = x.shape;
// output_shape[-1] = d;
// Tensor out = Tensor::empty(output_shape, x.scalar_type(), x.device());
// silu_and_mul(out, x);
// return out;
// }
// Tensor SiluAndMulQuant::forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
// Tensor out = SiluAndMul::forward(x);
// invoke_quant_fuse_sum(quantized_mlp_act_buffer, out, quantized_sum_buffer, quantized_scale_buffer);
// return out;
// }
// Tensor SiluAndMulQuant::forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
// Tensor out = SiluAndMul::forward(x);
// invoke_quant(quantized_mlp_act_buffer, out, quantized_scale_buffer, {});
// return out;
// }
#pragma once
#include "common.h"
#include "Tensor.h"
class Silu {
public:
static Tensor forward(Tensor x);
};
class GELU {
public:
static Tensor forward(Tensor x);
};
// class SiluAndMul {
// public:
// static Tensor forward(Tensor x);
// };
// class SiluAndMulQuant {
// public:
// static Tensor forward(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer, bool act_sum) {
// if (act_sum) {
// return forward_with_act_sum(x, quantized_mlp_act_buffer, quantized_scale_buffer, quantized_sum_buffer);
// } else {
// return forward_wo_act_sum(x, quantized_mlp_act_buffer, quantized_scale_buffer, quantized_sum_buffer);
// }
// }
// private:
// static Tensor forward_with_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);
// static Tensor forward_wo_act_sum(Tensor x, Tensor quantized_mlp_act_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);
// };
#pragma once
#include <cstddef>
#include <cassert>
#include <cmath>
#include <iostream>
#include <fstream>
#include <sstream>
#include <memory>
#include <source_location>
#include <vector>
#include <stack>
#include <map>
#include <unordered_map>
#include <set>
#include <any>
#include <variant>
#include <optional>
#include <chrono>
#include <functional>
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <spdlog/spdlog.h>
class CUDAError : public std::runtime_error {
public:
CUDAError(cudaError_t errorCode, std::source_location location)
: std::runtime_error(format(errorCode, location)), errorCode(errorCode), location(location) {}
public:
const cudaError_t errorCode;
const std::source_location location;
private:
static std::string format(cudaError_t errorCode, std::source_location location) {
return spdlog::fmt_lib::format("CUDA error: {} (at {}:{})",
cudaGetErrorString(errorCode), location.file_name(), location.line());
}
};
inline cudaError_t checkCUDA(cudaError_t retValue, const std::source_location location = std::source_location::current()) {
if (retValue != cudaSuccess) {
throw CUDAError(retValue, location);
}
return retValue;
}
inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue, const std::source_location location = std::source_location::current()) {
if (retValue != CUBLAS_STATUS_SUCCESS) {
throw std::runtime_error(spdlog::fmt_lib::format("CUBLAS error: {} (at {}:{})",
cublasGetStatusString(retValue), location.file_name(), location.line()));
}
return retValue;
}
inline thread_local std::stack<cudaStream_t> stackCUDAStreams;
inline cudaStream_t getCurrentCUDAStream() {
if (stackCUDAStreams.empty()) {
return 0;
}
return stackCUDAStreams.top();
}
inline cudaDeviceProp *getCurrentDeviceProperties() {
static thread_local cudaDeviceProp prop;
static thread_local bool propAvailable = false;
if (!propAvailable) {
int device;
checkCUDA(cudaGetDevice(&device));
checkCUDA(cudaGetDeviceProperties(&prop, device));
propAvailable = true;
}
return &prop;
}
template<typename T>
constexpr T ceilDiv(T a, T b) {
return (a + b - 1) / b;
}
struct CUBLASWrapper {
cublasHandle_t handle = nullptr;
CUBLASWrapper() {
checkCUBLAS(cublasCreate(&handle));
}
CUBLASWrapper(CUBLASWrapper &&) = delete;
CUBLASWrapper(const CUBLASWrapper &) = delete;
~CUBLASWrapper() {
if (handle) {
checkCUBLAS(cublasDestroy(handle));
}
}
};
inline std::shared_ptr<CUBLASWrapper> getCUBLAS() {
static thread_local std::weak_ptr<CUBLASWrapper> inst;
std::shared_ptr<CUBLASWrapper> result = inst.lock();
if (result) {
return result;
}
result = std::make_shared<CUBLASWrapper>();
inst = result;
return result;
}
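checkCUDA/checkCUBLAS turn error codes into exceptions annotated with the failing call site, getCurrentCUDAStream() reads a thread-local stream stack (falling back to the default stream), and getCUBLAS() caches one handle per thread through a weak_ptr so it is recreated only after every user has released it. A short usage sketch, assuming only the CUDA/cuBLAS calls shown (illustrative only):
#include "common.h"
void common_sketch() {
    checkCUDA(cudaSetDevice(0));                        // throws CUDAError on failure
    // Top of the thread-local stack, or the default stream if nothing was pushed.
    cudaStream_t stream = getCurrentCUDAStream();
    // Shared, thread-local cuBLAS handle; destroyed once the last shared_ptr drops.
    std::shared_ptr<CUBLASWrapper> cublas = getCUBLAS();
    checkCUBLAS(cublasSetStream(cublas->handle, stream));
}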
#pragma once
#include "common.h"
#include "Tensor.h"
class DebugContext {
public:
DebugContext() {
ctxs.insert(this);
}
DebugContext(const DebugContext &) = delete;
DebugContext(DebugContext &&) = delete;
~DebugContext() {
ctxs.erase(this);
}
std::map<std::string, Tensor> tensors;
static inline thread_local std::set<DebugContext *> ctxs;
};
#include "torch.h"
#include <ATen/cuda/CUDAContext.h>
using spdlog::fmt_lib::format;
template<typename To, typename Ti>
static To int_cast(Ti x) {
if (x < std::numeric_limits<To>::min() || x > std::numeric_limits<To>::max()) {
throw std::runtime_error("integer overflow");
}
return static_cast<To>(x);
}
Tensor from_torch(at::Tensor input) {
Tensor result;
const int ndims = int_cast<int>(input.ndimension());
for (int i = 0; i < ndims; i++) {
result.shape.dataExtent.push_back(int_cast<decltype(result.shape.dataExtent)::value_type>(input.size(i)));
result.shape.dataStride.push_back(int_cast<decltype(result.shape.dataStride)::value_type>(input.stride(i)));
}
static const std::map<at::ScalarType, Tensor::ScalarType> mapType = {
{ at::ScalarType::Byte, Tensor::INT8 },
{ at::ScalarType::Int, Tensor::INT32 },
{ at::ScalarType::Long, Tensor::INT64 },
{ at::ScalarType::Float, Tensor::FP32 },
{ at::ScalarType::Half, Tensor::FP16 },
{ at::ScalarType::BFloat16, Tensor::BF16 },
};
result.scalarType = mapType.at(input.scalar_type());
result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
return result;
}
at::Tensor to_torch(Tensor input) {
assert(input.is_contiguous());
std::vector<int64_t> shape;
for (size_t i = 0; i < input.ndims(); i++) {
shape.push_back(input.size(i));
}
static const std::map<Tensor::ScalarType, at::ScalarType> mapType = {
{ Tensor::INT8, at::ScalarType::Byte },
{ Tensor::INT32, at::ScalarType::Int },
{ Tensor::INT64, at::ScalarType::Long },
{ Tensor::FP32, at::ScalarType::Float },
{ Tensor::FP16, at::ScalarType::Half },
{ Tensor::BF16, at::ScalarType::BFloat16 },
};
c10::TensorOptions opts(mapType.at(input.scalar_type()));
if (input.device().type == Device::CPU) {
opts = opts.device("cpu");
} else {
opts = opts.device(format("cuda:{}", input.device().idx));
}
at::Tensor result = torch::empty(at::IntArrayRef(shape), opts);
from_torch(result).copy_(input);
return result;
}
TorchOpContext::TorchOpContext() {
stackCUDAStreams.push(at::cuda::getCurrentCUDAStream().stream());
}
TorchOpContext::~TorchOpContext() {
assert(stackCUDAStreams.top() == at::cuda::getCurrentCUDAStream().stream());
stackCUDAStreams.pop();
}
#pragma once
#include <torch/extension.h>
#include "common.h"
#include "Tensor.h"
class BufferTorchTensor : public Buffer {
public:
BufferTorchTensor(at::Tensor tensor) : tensor(std::move(tensor)) {
this->size = this->tensor.numel() * this->tensor.itemsize();
this->ptr = this->tensor.data_ptr();
this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU;
this->device.idx = this->tensor.get_device();
}
private:
at::Tensor tensor;
};
class TorchOpContext {
public:
TorchOpContext();
TorchOpContext(const TorchOpContext &) = delete;
TorchOpContext(TorchOpContext &&) = delete;
~TorchOpContext();
};
Tensor from_torch(at::Tensor input);
at::Tensor to_torch(Tensor input);
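from_torch() wraps the ATen storage in a BufferTorchTensor without copying, to_torch() allocates a fresh at::Tensor and copies into it, and TorchOpContext pushes ATen's current CUDA stream for the lifetime of the op. A sketch of how a torch extension op might tie these together with the activation layers above (the op name and include paths are assumptions):
#include "torch.h"
#include "activation.h"
at::Tensor silu_op(at::Tensor x) {
    // Route every kernel launched below onto ATen's current CUDA stream.
    TorchOpContext ctx;
    Tensor input = from_torch(x.contiguous());  // zero-copy wrap of the torch storage
    Tensor out   = Silu::forward(input);        // allocates the output and launches silu()
    return to_torch(out);                       // copies into a newly allocated at::Tensor
}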
#include "activation_kernels_impl.cuh"
#include "activation_kernels.h"
#include "dispatch_utils.h"
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"activation_kernel", \
[&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
void silu_and_mul(
Tensor& out, // [..., d]
Tensor& input) // [..., 2 * d]
{
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
// dispatchFloat(input.scalar_type(), [&]<typename scalar_t>() {
// vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
// out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
// });
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] {
vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
});
}
void invoke_dequant_silu_and_mul_quant(
Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate, const float scale_up, const float scale_out) {
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float, false><<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<int32_t>(), d, scale_gate,
scale_up, scale_out);
}
void invoke_dequant_silu_and_mul_quant(
Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate, const float scale_up,
Tensor &scale_out, // [num_tokens]
Tensor &tmp // [..., d]
) {
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float*, true><<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<int32_t>(),
d, scale_gate, scale_up, scale_out.data_ptr<float>(), tmp.data_ptr<float>());
}
void silu(
Tensor& out, // [..., d]
Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::silu);
}
void gelu_new(
Tensor& out, // [..., d]
Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}
void gelu_fast(
Tensor& out, // [..., d]
Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
#pragma once
#include "common.h"
#include "Tensor.h"
void silu(
Tensor& out, // [..., d]
Tensor& input);
void silu_and_mul(Tensor &out, // [..., d]
Tensor &input); // [..., 2 * d]
void gelu_new(Tensor &out, Tensor &input);
void gelu_fast(Tensor &out, Tensor &input);
void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate,
const float scale_up,
const float scale_out);
void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
Tensor &input, // [..., 2 * d]
const float scale_gate,
const float scale_up,
Tensor &scale_out, // [num_tokens]
Tensor &tmp // [num_tokens, d]
);
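These kernels expect the caller to pre-allocate out; for silu_and_mul the last dimension of input holds the gate and up halves concatenated, so out is half as wide. A caller sketch that mirrors the commented-out SiluAndMul::forward earlier in this commit (illustrative only; assumes TensorShape's negative indexing behaves as in that commented code):
#include "Tensor.h"
#include "kernels/activation_kernels.h"
Tensor gated_silu(Tensor gate_up) {             // [..., 2 * d]
    int d = gate_up.size(-1) / 2;
    TensorShape out_shape = gate_up.shape;      // copy the shape, then halve the last dim
    out_shape[-1] = d;
    Tensor out = Tensor::empty(out_shape, gate_up.scalar_type(), gate_up.device());
    silu_and_mul(out, gate_up);                 // out[..., i] = silu(gate[..., i]) * up[..., i]
    return out;
}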
#include "utils.cuh"
#include "reduction_utils.cuh"
namespace vllm {
template <typename T> __device__ __forceinline__ T silu(const T &x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}
template<typename scalar_t>
__global__ void silu_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2 * d]
const int d) {
const int token_idx = blockIdx.x;
const int64_t token_idx_d = token_idx * int64_t(d);
const int64_t token_idx_2d = token_idx_d * 2;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx_2d + idx]);
const scalar_t y = __ldg(&input[token_idx_2d + d + idx]);
out[token_idx_d + idx] = silu(x) * y;
}
}
// dequant int32 input, apply silu and mul, then per token quant to int8
template <typename scale_type, bool use_per_token_quant>
__global__ void dequant_silu_and_mul_quant_kernel(
int8_t *__restrict__ out, // [..., d]
const int32_t *__restrict__ input, // [..., 2 * d]
const int d, const float scale_gate, const float scale_up,
scale_type scale_out, // [num_tokens]
float *__restrict__ tmp = nullptr // [num_tokens, d]
) {
const int token_idx = blockIdx.x;
if constexpr (use_per_token_quant) {
float amax_val = 0.0f;
const float zero = 0.0f;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x =
(float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y =
(float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
float t = silu(x) * y;
tmp[token_idx * d + idx] = t;
t = t > zero ? t : -t;
if (t > amax_val)
amax_val = t;
}
__shared__ float s_amax;
const float block_amax_val = blockReduceMax(amax_val);
if (threadIdx.x == 0) {
s_amax = block_amax_val;
scale_out[token_idx] = block_amax_val / 127.0f;
}
__syncthreads();
float tmp_scale = 127.0f / s_amax;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
out[token_idx * d + idx] =
float_to_int8_rn(tmp_scale * tmp[token_idx * d + idx]);
}
} else {
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const float x =
(float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
const float y =
(float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
out[token_idx * d + idx] = float_to_int8_rn(silu(x) * y / scale_out);
}
}
}
} // namespace vllm
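In the per-token branch the kernel first materializes $t_{ij} = \mathrm{silu}(x_{ij} s_{\text{gate}}) \cdot (y_{ij} s_{\text{up}})$ in tmp, reduces the per-token absolute maximum, and then requantizes symmetrically to int8:
$$ s_{\text{out}}[i] = \frac{\max_j |t_{ij}|}{127}, \qquad q_{ij} = \operatorname{round}\!\left(\frac{127\, t_{ij}}{\max_j |t_{ij}|}\right). $$
For example, a token whose largest magnitude is 6.35 gets scale 0.05, and a value of 3.17 quantizes to round(3.17 / 0.05) = 63; the extreme values map to ±127.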
namespace vllm {
// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
} // namespace vllm
namespace vllm {
template <typename T> __device__ __forceinline__ T gelu_new_kernel(const T &x) {
const float x3 = (float)(x * x * x);
const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T)0.5) * x * (((T)1.0) + t);
}
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T &x) {
const float f = (float)x;
const T t =
(T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
}
} // namespace vllm
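Both device functions implement the tanh approximation of GELU, with 0.79788456 ≈ √(2/π):
$$ \operatorname{gelu}(x) \approx \tfrac{1}{2}\,x\left(1 + \tanh\!\big(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\big)\right). $$
gelu_fast_kernel evaluates the algebraically identical factoring √(2/π)·x·(1 + 0.044715·x²) inside the tanh.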
/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#pragma once
#include <cuda_fp16.h>
#include <cstdint>
__forceinline__ __device__
void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
}
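The lop3/sub/fma sequence is the standard magic-number int4-to-fp16 conversion: each unsigned 4-bit value is OR'ed into the mantissa of the fp16 constant 1024.0 (0x6400), so the low nibbles decode as 1024 + n and one subtraction recovers n, while the nibbles left four bits higher decode as 1024 + 16·n and are recovered by one fma with multiplier 1/16 and addend -64. A scalar check of those two identities (illustrative only, independent of the interleaved packing):
#include <cassert>
int main() {
    for (int n = 0; n < 16; ++n) {
        // Low-nibble path: OR into the mantissa of 1024.0h, then subtract 1024.
        float lo = (1024.0f + float(n)) - 1024.0f;
        // High-nibble path: the value sits 4 bits up (16 * n), so
        // fma(h, 1/16, -64) maps 1024 + 16 * n back to n.
        float hi = (1024.0f + 16.0f * float(n)) * (1.0f / 16.0f) - 64.0f;
        assert(lo == float(n));
        assert(hi == float(n));
    }
    return 0;
}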
__forceinline__ __device__
void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
// dequantize_s4_to_fp16x2(reinterpret_cast<const half2 &>(source), result);
// *reinterpret_cast<__nv_bfloat162 *>(&result->x) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->x));
// *reinterpret_cast<__nv_bfloat162 *>(&result->y) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->y));
// *reinterpret_cast<__nv_bfloat162 *>(&result->z) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
// *reinterpret_cast<__nv_bfloat162 *>(&result->w) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->w));
// return;
// uint4 result;
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
// Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 ((i4s >> 4) & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s >> 4), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 ((i4s >> 8) & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(i4s >> 8), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 ((i4s >> 12) & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(i4s >> 12), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// static constexpr uint32_t BF16_BIAS = 0xC308C308;
// This is the BF16 {-128, -128} represented as an integer, we do not need to map to [-8, 7]
static constexpr uint32_t BF16_BIAS = 0xC300C300;
static constexpr uint32_t BF16_ONE = 0x3F803F80;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(BF16_ONE), "r"(BF16_BIAS));
// Convert elt_23
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(BF16_ONE), "r"(BF16_BIAS));
// Convert elt_45
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
// Convert elt_67
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
}
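The bf16 variant uses the same trick with different constants: 0x4300 is 128.0 in bf16, whose 7-bit mantissa still has room for a 4-bit value, and because every pair is extracted from the low nibble positions (after the explicit shifts) each decodes directly as 128 + n. A single fma per pair then recovers the value, with no 1/16 correction step:
$$ (128 + n)\cdot 1 + (-128) = n, \qquad n \in [0, 15]. $$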