Unverified commit b1f80b8a authored by Mohamed Hisham, committed by GitHub

[CUDA] Branchless NF4/FP4 kDequantizeBlockwise kernel for faster dequantization (#1746)

* Added branchless LUT-based dequantization for FP4 and NF4

* Added extra command line options to control reproducibility

* Restore FP4 quantization/dequantization order
parent c9bce2b4
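
In brief: the old kernels decoded each 4-bit code by walking a tree of nested branches, which diverges across a warp; the new kernels flip the sign arithmetically and read the magnitude from a small `__device__` lookup table, so every thread in a warp executes the same instruction stream. A minimal standalone sketch of the FP4 trick (plain C++ host code; `fp4_lut`, `dequantize_fp4`, and the `main()` driver are hypothetical names, but the table values and sign arithmetic are the ones in the diff below):

```cpp
// Sketch of the branchless FP4 dequantization introduced by this commit.
// Host-side and illustrative only; the kernel version lives in the diff below.
#include <cstdio>

static const float fp4_lut[8] = {
    0.0f, 0.005208333333f, 0.66666667f, 1.0f, // codes 0b000..0b011
    0.33333333f, 0.5f, 0.16666667f, 0.25f,    // codes 0b100..0b111
};

// Bit 3 carries the sign, bits 0-2 index the magnitude table: no branches,
// hence no warp divergence when the same pattern runs on the GPU.
float dequantize_fp4(unsigned char val, float absmax) {
    float sign = 1.0f - 2.0f * ((val & 0b1000) >> 3); // +1 if bit 3 clear, -1 if set
    return fp4_lut[val & 0b111] * sign * absmax;
}

int main() {
    unsigned char packed = 0xB3; // two 4-bit codes: 0b1011 (high), 0b0011 (low)
    printf("%f %f\n", dequantize_fp4(packed >> 4, 2.0f),  // -1.0 * 2.0 = -2.0
           dequantize_fp4(packed & 0x0F, 2.0f));          // +1.0 * 2.0 = +2.0
}
```

Note also that the `absmax` scaling moves out of the decode function, which is what lets `kDequantizeBlockwise` multiply once per value after the lookup.
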
@@ -21,6 +21,9 @@ options:
--batches BATCHES [BATCHES ...]
--input-length INPUT_LENGTH
--out-dir OUT_DIR
+  --iterations ITERATIONS
+  --warmup-runs WARMUP_RUNS
+  --output-length OUTPUT_LENGTH
"""
import argparse
@@ -30,6 +33,9 @@ from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig
from optimum_benchmark.logging_utils import setup_logging
import torch
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8
WEIGHTS_CONFIGS = {
@@ -73,9 +79,8 @@ WEIGHTS_CONFIGS = {
},
}
-if __name__ == "__main__":
-    setup_logging(level="INFO")
+def parse_args():
parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool")
parser.add_argument("model_id", type=str, help="The model checkpoint to use.")
@@ -98,20 +103,36 @@ if __name__ == "__main__":
parser.add_argument("--out-dir", type=str, default="reports")
-    args = parser.parse_args()
+    parser.add_argument("--iterations", type=int, default=10, help="Number of iterations for each benchmark run")
+    parser.add_argument(
+        "--warmup-runs", type=int, default=10, help="Number of warmup runs to discard before measurement"
+    )
+    parser.add_argument(
+        "--output-length",
+        type=int,
+        default=64,
+        help="If set, `max_new_tokens` and `min_new_tokens` will be set to this value.",
+    )
-    out_dir = Path(args.out_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
+    return parser.parse_args()
-    for batch_size in args.batches:
-        print(f"Benchmarking batch size: {batch_size}")
-        for config in args.configs:
-            launcher_config = ProcessConfig(device_isolation=True, start_method="spawn")
+def run_benchmark(args, config, batch_size):
+    launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="warn", start_method="spawn")
scenario_config = InferenceConfig(
latency=True,
memory=True,
input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
+        iterations=args.iterations,
+        warmup_runs=args.warmup_runs,
+        # set duration to 0 to disable the duration-based stopping criterion
+        # this is IMPORTANT to ensure that all benchmarks run the same number of operations, regardless of hardware speed/bottlenecks
+        duration=0,
+        # for consistent results, set a fixed min and max for output tokens
+        generate_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length},
+        forward_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length},
)
backend_config = PyTorchConfig(
device="cuda",
device_ids="0",
@@ -120,15 +141,35 @@ if __name__ == "__main__":
model=args.model_id,
**WEIGHTS_CONFIGS[config],
)
+    test_name = (
+        f"benchmark-{config}"
+        f"-bsz-{batch_size}"
+        f"-isz-{args.input_length}"
+        f"-osz-{args.output_length}"
+        f"-iter-{args.iterations}"
+        f"-wrmup-{args.warmup_runs}"
+    )
benchmark_config = BenchmarkConfig(
name=f"benchmark-{config}-bsz{batch_size}",
name=test_name,
scenario=scenario_config,
launcher=launcher_config,
backend=backend_config,
)
out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json"
out_path = out_dir / (test_name + ".json")
print(f"[{test_name}] Starting:")
benchmark_report = Benchmark.launch(benchmark_config)
benchmark_report.log()
benchmark_report.save_json(out_path)
+if __name__ == "__main__":
+    setup_logging(level="INFO")
+    args = parse_args()
+
+    out_dir = Path(args.out_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    for batch_size in args.batches:
+        for config in args.configs:
+            run_benchmark(args, config, batch_size)
@@ -21,23 +21,34 @@
#define NUM 4
#define NUM_BLOCK 4096
-__device__ static float nf4_data[16] = {
-    -1.0,
-    -0.6961928009986877,
-    -0.5250730514526367,
-    -0.39491748809814453,
-    -0.28444138169288635,
-    -0.18477343022823334,
-    -0.09105003625154495,
-    0.0,
-    0.07958029955625534,
-    0.16093020141124725,
-    0.24611230194568634,
-    0.33791524171829224,
-    0.44070982933044434,
-    0.5626170039176941,
-    0.7229568362236023,
-    1.0
+__device__ static float fp4_dequantization_lut[8] = {
+    0.0f,            // 0b000
+    0.005208333333f, // 0b001
+    0.66666667f,     // 0b010
+    1.0f,            // 0b011
+    0.33333333f,     // 0b100
+    0.5f,            // 0b101
+    0.16666667f,     // 0b110
+    0.25f            // 0b111
+};
+
+__device__ static float nf4_dequantization_lut[16] = {
+    -1.0f,                 // 0b0000
+    -0.6961928009986877f,  // 0b0001
+    -0.5250730514526367f,  // 0b0010
+    -0.39491748809814453f, // 0b0011
+    -0.28444138169288635f, // 0b0100
+    -0.18477343022823334f, // 0b0101
+    -0.09105003625154495f, // 0b0110
+    0.0f,                  // 0b0111
+    0.07958029955625534f,  // 0b1000
+    0.16093020141124725f,  // 0b1001
+    0.24611230194568634f,  // 0b1010
+    0.33791524171829224f,  // 0b1011
+    0.44070982933044434f,  // 0b1100
+    0.5626170039176941f,   // 0b1101
+    0.7229568362236023f,   // 0b1110
+    1.0f                   // 0b1111
};
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
@@ -51,27 +62,9 @@ __device__ float atomicMax(float* address, float val) {
return __int_as_float(old);
}
-__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) {
-    float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
-    if ((val & 0b0100) == 4)                            // 0
-        if ((val & 0b0010) == 2)                        // 01
-            if ((val & 0b0001) == 1)                    // 111
-                return 0.25000000f * absmax * sign;     // 1111
-            else
-                return 0.16666667f * absmax * sign;     // 1110
-        else if ((val & 0b0001) == 1)                   // 110
-            return 0.50000000f * absmax * sign;         // 1101
-        else
-            return 0.33333333f * absmax * sign;         // 1100
-    else if ((val & 0b0010) == 2)                       // 10
-        if ((val & 0b0001) == 1)                        // 101
-            return 1.00000000f * absmax * sign;         // 1011
-        else
-            return 0.66666667f * absmax * sign;         // 1010
-    else if ((val & 0b0001) == 1)                       // 100
-        return 5.208333333e-03f * absmax * sign;        // 1001
-    else
-        return 0.00000000f * absmax * sign;             // 1000
+__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) {
+    float sign = 1.0f - 2 * ((val & 0b1000) >> 3);
+    return fp4_dequantization_lut[val & 0b111] * sign;
}
__device__ unsigned char dQuantizeFP4(float x) {
@@ -118,51 +111,7 @@ __device__ unsigned char dQuantizeFP4(float x) {
return 0b0000 + sign;
}
-__device__ __forceinline__ float dDequantizeNF4(unsigned char val) {
-    // the values for this tree was generated by test_normal_map_tree
-    // in the file tests/test_functional.py
-    if ((val & 0b1000) == 8)
-        if ((val & 0b0100) == 4)                    // 1
-            if ((val & 0b0010) == 2)                // 11
-                if ((val & 0b0001) == 1)            // 111
-                    return 1.0f;
-                else
-                    return 0.7229568362236023f;
-            else if ((val & 0b0001) == 1)           // 110
-                return 0.5626170039176941f;
-            else
-                return 0.44070982933044434f;
-        else if ((val & 0b0010) == 2)               // 10
-            if ((val & 0b0001) == 1)                // 101
-                return 0.33791524171829224f;
-            else
-                return 0.24611230194568634f;
-        else if ((val & 0b0001) == 1)               // 100
-            return 0.16093020141124725f;
-        else
-            return 0.07958029955625534f;
-    else if ((val & 0b0100) == 4)                   // 0
-        if ((val & 0b0010) == 2)                    // 01
-            if ((val & 0b0001) == 1)                // 011
-                return 0.0f;
-            else
-                return -0.09105003625154495f;
-        else if ((val & 0b0001) == 1)               // 010
-            return -0.18477343022823334f;
-        else
-            return -0.28444138169288635f;
-    else if ((val & 0b0010) == 2)                   // 00
-        if ((val & 0b0001) == 1)                    // 001
-            return -0.39491748809814453f;
-        else
-            return -0.5250730514526367f;
-    else if ((val & 0b0001) == 1)                   // 000
-        return -0.6961928009986877f;
-    else
-        return -1.0f;
-}
+__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }
__device__ unsigned char dQuantizeNF4(float x) {
@@ -510,8 +459,8 @@ __global__ void
case FP4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
-                vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
-                vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
+                vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max;
+                vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max;
}
break;
case NF4:
@@ -2352,7 +2301,7 @@ __global__ void kgemm_4bit_inference(
#pragma unroll 16
for (int i = 0; i < 16; i++)
-        quant_map[i] = nf4_data[i];
+        quant_map[i] = nf4_dequantization_lut[i];
//__shared__ T quant_map[16*160];
T local_A[2];
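
For NF4 the collapse is even simpler: the 4-bit code already orders the sixteen quantiles monotonically from -1.0 to 1.0, so the 40-odd-line decision tree reduces to a single indexed load, with `absmax` applied by the caller. A standalone sketch that prints the full code-to-value mapping (illustrative; `nf4_lut` is a hypothetical host-side copy of the table added above):

```cpp
// Standalone check that NF4 codes map monotonically onto the quantile table,
// which is what allows the kernel to replace the old branch tree with one load.
#include <cstdio>

static const float nf4_lut[16] = {
    -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
    0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f,
    0.33791524171829224f, 0.44070982933044434f, 0.5626170039176941f,
    0.7229568362236023f, 1.0f,
};

int main() {
    // Print each 4-bit code with its dequantized value; absmax scaling is
    // applied by the caller, exactly as in the kernel's NF4 case.
    for (int v = 0; v < 16; ++v)
        printf("0b%d%d%d%d -> %+.6f\n", (v >> 3) & 1, (v >> 2) & 1, (v >> 1) & 1, v & 1, nf4_lut[v]);
}
```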