Unverified commit 070c45bb authored by Muyang Li, committed by GitHub

docs: add the docstrings for v1.0.0 (#656)

* add v2 flux examples

* add the docs

* add docs

* update

* finished ops

* add ops

* update

* update

* update

* update

* update

* update

* update

* update docstrings

* update

* update

* update

* update

* update

* update

* update

* finished the api docs

* update

* update
parent e0392e42
"""
Python wrappers for Nunchaku's quantization operations.
This module provides Python wrappers for Nunchaku's high-performance SVDQuant quantization CUDA kernels.
"""
import torch
@@ -18,22 +18,49 @@ def svdq_quantize_w4a4_act_fuse_lora_cuda(
    fuse_glu: bool = False,
    fp4: bool = False,
    pad_size: int = 256,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
-    This function wraps the high-performance CUDA kernel for SVDQuant W4A4 quantized GEMM.
-
-    Notation
-    --------
-    M : int
-        Batch size (number of input samples).
-    K : int
-        Number of input channels (feature dimension).
-    N : int
-        Number of output channels.
-    G : int
-        Number of groups. 64 for INT4 and 16 for NVFP4.
-    R : int
-        Rank of the low-rank branch.
+    Quantizes activations and computes the LoRA down-projection using the SVDQuant W4A4 CUDA kernel.
+
+    Parameters
+    ----------
+    input : torch.Tensor, shape (M, K), dtype bfloat16/float16
+        Input activations.
+    output : torch.Tensor or None, shape (M_pad, K // 2), dtype uint8, optional
+        Packed output tensor for quantized activations. Allocated if None.
+    oscales : torch.Tensor or None, shape (K // G, M_pad), dtype float8_e4m3fn for NVFP4 or input dtype for INT4, optional
+        Output scales tensor. Allocated if None.
+    lora_down : torch.Tensor or None, shape (K, R), dtype bfloat16/float16, optional
+        Packed LoRA down-projection weights.
+    lora_act_out : torch.Tensor or None, shape (M_pad, R), dtype float32, optional
+        Packed output tensor for LoRA activations. Allocated if None.
+    smooth : torch.Tensor or None, dtype bfloat16/float16, optional
+        Smoothing factor for quantization.
+    fuse_glu : bool, default=False
+        If True, fuse the GLU activation.
+    fp4 : bool, default=False
+        If True, use NVFP4 quantization; otherwise INT4.
+    pad_size : int, default=256
+        Pad the batch size to a multiple of this value for efficient CUDA execution.
+
+    Returns
+    -------
+    output : torch.Tensor, shape (M_pad, K // 2), dtype uint8
+        Packed quantized activations.
+    oscales : torch.Tensor, shape (K // G, M_pad), dtype float8_e4m3fn for NVFP4 or input dtype for INT4
+        Output scales.
+    lora_act_out : torch.Tensor, shape (M_pad, R), dtype float32
+        Packed LoRA activation output.
+
+    Notes
+    -----
+    Notation:
+
+    - M: batch size
+    - K: input channels
+    - R: LoRA rank
+    - G: group size (64 for INT4, 16 for NVFP4)
+    - M_pad: padded batch size = ceil(M / pad_size) * pad_size
"""
batch_size, channels = input.shape
rank = lora_down.shape[1]
......
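For reference, a minimal usage sketch of the new API. This is a sketch under assumptions, not code from the commit: the `nunchaku.ops.quantize` import path, the concrete sizes for `M`, `K`, and `R`, and the behavior that the omitted `output`/`oscales`/`lora_act_out` arguments default to `None` (and are then allocated by the kernel) are inferred from the docstring above rather than shown in the diff.

```python
import math

import torch

# NOTE: illustrative only -- the commit does not show the module path,
# so this import is an assumption.
from nunchaku.ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda

M, K, R = 1000, 3072, 32  # batch size, input channels, LoRA rank (example values)
pad_size = 256
M_pad = math.ceil(M / pad_size) * pad_size  # ceil(1000 / 256) * 256 = 1024

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
lora_down = torch.randn(K, R, dtype=torch.bfloat16, device="cuda")

# Pass None for output / oscales / lora_act_out so the wrapper allocates them
# (per the docstring: "Allocated if None").
output, oscales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda(
    x,
    lora_down=lora_down,
    fp4=False,         # INT4 path, so the group size G is 64
    pad_size=pad_size,
)

assert output.shape == (M_pad, K // 2)    # two 4-bit values packed per uint8
assert oscales.shape == (K // 64, M_pad)  # K // G groups of scales
assert lora_act_out.shape == (M_pad, R)   # float32 LoRA down-projection output
```

Padding the batch dimension to a multiple of `pad_size` presumably keeps the CUDA grid uniformly shaped; the rows of the padded outputs beyond the first M would then be filler and should be sliced off before further use.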