Unverified Commit 19fe95ac authored by Matthew Douglas's avatar Matthew Douglas Committed by GitHub
Browse files

Merge pull request #1721 from Mhmd-Hisham/quantization-packing-bug-fix

[CUDA] Fixing quantization uint8 packing bug for NF4 and FP4
parents 42653921 639f8c05
...@@ -431,7 +431,6 @@ __global__ void kQuantizeBlockwise( ...@@ -431,7 +431,6 @@ __global__ void kQuantizeBlockwise(
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
} }
unsigned char packed_4bit = 0;
switch (DATA_TYPE) { switch (DATA_TYPE) {
case General8bit: case General8bit:
#pragma unroll NUM_PER_TH #pragma unroll NUM_PER_TH
...@@ -445,17 +444,15 @@ __global__ void kQuantizeBlockwise( ...@@ -445,17 +444,15 @@ __global__ void kQuantizeBlockwise(
case FP4: case FP4:
#pragma unroll NUM_PER_TH #pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH / 2; j++) { for (int j = 0; j < NUM_PER_TH / 2; j++) {
packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
qvals[j] = packed_4bit;
} }
break; break;
case NF4: case NF4:
#pragma unroll NUM_PER_TH #pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH / 2; j++) { for (int j = 0; j < NUM_PER_TH / 2; j++) {
packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
qvals[j] = packed_4bit;
} }
break; break;
} }
......
...@@ -1125,21 +1125,52 @@ class TestQuantize4BitFunctional: ...@@ -1125,21 +1125,52 @@ class TestQuantize4BitFunctional:
# With larger block sizes, we can expect this to blow up. # With larger block sizes, we can expect this to blow up.
# At blocksize>=1024, don't even bother looking at relerr. # At blocksize>=1024, don't even bother looking at relerr.
if blocksize <= 64: #
assert err.item() < 0.1 # Actually, the above is not true anymore after fixing the integer packing bug.
assert relerr.item() < 0.28 # The following values were taken from averaging 1k samples per test configuration after fixing the bug.
elif blocksize <= 256: error_dict = dict()
assert err.item() < 0.11 error_dict["fp4"] = dict()
assert relerr.item() < 0.30 error_dict["nf4"] = dict()
elif blocksize <= 512: error_dict["fp4"]["err"] = {
assert err.item() < 0.12 64: 0.096545,
assert relerr.item() < 0.31 128: 0.102947,
elif quant_type == "fp4": 256: 0.108685,
# 1024 => 0.48, 2048 => 0.52, 4096 => 0.56 512: 0.114087,
assert err.item() < 0.08 + math.log2(blocksize) * 4e-2 1024: 0.119312,
else: 2048: 0.124460,
# 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 4096: 0.129573,
assert err.item() < math.log2(blocksize) * 8e-2 }
error_dict["fp4"]["rel_err"] = {
64: 0.260130,
128: 0.275734,
256: 0.289842,
512: 0.302852,
1024: 0.314982,
2048: 0.326402,
4096: 0.337228,
}
error_dict["nf4"]["err"] = {
64: 0.072792,
128: 0.076835,
256: 0.080326,
512: 0.083535,
1024: 0.086603,
2048: 0.089592,
4096: 0.092537,
}
error_dict["nf4"]["rel_err"] = {
64: 0.203299,
128: 0.215252,
256: 0.226044,
512: 0.236021,
1024: 0.245365,
2048: 0.254146,
4096: 0.262457,
}
assert err < error_dict[quant_type]["err"][blocksize] + 1e-3
assert relerr < error_dict[quant_type]["rel_err"][blocksize] + 1e-3
@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment