Commit 1f9c104b authored by wenjh
Browse files

Merge branch 'develop_v2.4'

parents 2b1428ff 8a03ff34
......@@ -263,30 +263,16 @@ void compare_scaling_factors(const std::string& name, const float* test, const f
void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test,
const float* ref, const size_t rows,
const size_t col_blocks
#ifdef __HIP_PLATFORM_AMD__
, double atol = 0., double rtol = 0.
#endif
) {
const size_t col_blocks) {
const size_t test_stride = scale_align_stride(rows);
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < col_blocks; ++j) {
const int test_idx = i + test_stride * j;
const int ref_idx = i + rows * j;
#ifdef __HIP_PLATFORM_AMD__
double t = static_cast<double>(static_cast<float>(test[test_idx]));
double r = static_cast<double>(static_cast<float>(ref[ref_idx]));
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
ASSERT_FALSE(mismatch)
<< "Error in " << name << std::endl
<< "Mismatch: " << t << " vs " << r << " at index " << test_idx
<< "," << ref_idx;
#else
ASSERT_FALSE(test[test_idx] != ref[ref_idx])
<< "Error in " << name << std::endl
<< "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx
<< "," << ref_idx;
#endif
}
}
}
......@@ -425,33 +411,17 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
float atol = 0.0;
float rtol = 0.0;
#ifdef __HIP_PLATFORM_AMD__
double atol_scale = 0.0;
double rtol_scale = 0.0;
if(itype == DType::kFloat32)
{
atol_scale = 1e-5;
}
#endif
if (rowwise) {
compareResults("output_c", output_c, ref_output.get(), true, atol, rtol);
compare_scaling_factors_one_dimensional_blocks("scale_inv",
output_c.rowwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv.get(), rows, blocks_x
#ifdef __HIP_PLATFORM_AMD__
, atol_scale, rtol_scale
#endif
);
ref_scale_inv.get(), rows, blocks_x);
}
if (colwise) {
compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol);
compare_scaling_factors_one_dimensional_blocks("scale_inv_t",
output_c.columnwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv_t.get(), cols, blocks_x_t
#ifdef __HIP_PLATFORM_AMD__
, atol_scale, rtol_scale
#endif
);
ref_scale_inv_t.get(), cols, blocks_x_t);
}
}
......
......@@ -171,7 +171,11 @@ class BlockwiseQuantizerReference:
qx = x_tiled * scale.reshape(M, K // tile_len, 1)
qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
if quant_dtype == torch.int8:
qx = torch.round(qx)
positive_mask = qx >= 0
negative_mask = ~positive_mask
pos_part = torch.where(positive_mask, torch.floor(qx + 0.5), 0)
neg_part = torch.where(negative_mask, torch.ceil(qx - 0.5), 0)
qx = pos_part + neg_part
qx = qx.to(dtype=quant_dtype)
qx = qx.reshape(M, K)
return qx, scale_inv
......
......@@ -4,7 +4,7 @@
from typing import Tuple
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
def scale_from_amax_tensor(
x_dtype: torch.dtype,
......@@ -48,6 +48,10 @@ def scale_from_amax_tensor(
# No subnormals and zero.
assert (exp > -127).all()
unity = torch.tensor([1.0], device=exp.device)
if IS_HIP_EXTENSION:
host_scale = torch.ldexp(unity.cpu(), exp.cpu())
scale = host_scale.to(exp.device)
else:
torch.ldexp(unity, exp, out=scale)
# Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
# Return 0.0 for 0.0 scale for consistency with non-pow2 scale
......
......@@ -273,7 +273,7 @@ def check_quantization_block_tiling_versus_reference(
)
# Check
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0)
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
# Zero out values that are don't care values
# Scale format has padding.
scale_mask = torch.ones(
......@@ -283,7 +283,7 @@ def check_quantization_block_tiling_versus_reference(
QuantizeResult(qx, scale_mask, None, None), tile_size
).scale
sx = sx * scale_mask
torch.testing.assert_close(sx, sx_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
assert qx_t is not None
......@@ -299,8 +299,8 @@ def check_quantization_block_tiling_versus_reference(
QuantizeResult(qx_t, scale_mask, None, None), tile_size
).scale
sx_t = sx_t * scale_mask
torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0 if x_dtype != torch.float32 else 2.5e-1)
torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
else:
# should be None
assert qx_t is None and qx_t_ref is None
......
......@@ -187,8 +187,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
// Step 3: Store cast output
CType scale_data = block_tile_scale;
OType scaled_elt =
OType scaled_elt = 0;
if constexpr(std::is_same_v<OType, int8_t>) {
scaled_elt =
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data))));
}
else {
scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
}
tmp_output_c.data.elt[j] = scaled_elt;
// Step 4: do transpose within thread tile
if constexpr (kReturnTranspose) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment