"tests/vscode:/vscode.git/clone" did not exist on "94a948722f2be3834c00e5e158d277d623750c44"
Commit 3e38a2ea authored by wenjh
Browse files

[Workaround] Use bf16 lds to save fp32 input



The quantize_transpose_vector_blockwise function requires more than 64 KB of LDS when the
input type is fp32. Since the maximum LDS size on DCU is 64 KB, as a workaround we store
the fp32 input in LDS as bf16.
Signed-off-by: wenjh <wenjh@sugon.com>
parent 32184507
...@@ -263,16 +263,30 @@ void compare_scaling_factors(const std::string& name, const float* test, const f ...@@ -263,16 +263,30 @@ void compare_scaling_factors(const std::string& name, const float* test, const f
void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test, void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test,
const float* ref, const size_t rows, const float* ref, const size_t rows,
const size_t col_blocks) { const size_t col_blocks
#ifdef __HIP_PLATFORM_AMD__
, double atol = 0., double rtol = 0.
#endif
) {
const size_t test_stride = scale_align_stride(rows); const size_t test_stride = scale_align_stride(rows);
for (int i = 0; i < rows; ++i) { for (int i = 0; i < rows; ++i) {
for (int j = 0; j < col_blocks; ++j) { for (int j = 0; j < col_blocks; ++j) {
const int test_idx = i + test_stride * j; const int test_idx = i + test_stride * j;
const int ref_idx = i + rows * j; const int ref_idx = i + rows * j;
#ifdef __HIP_PLATFORM_AMD__
double t = static_cast<double>(static_cast<float>(test[test_idx]));
double r = static_cast<double>(static_cast<float>(ref[ref_idx]));
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
ASSERT_FALSE(mismatch)
<< "Error in " << name << std::endl
<< "Mismatch: " << t << " vs " << r << " at index " << test_idx
<< "," << ref_idx;
#else
ASSERT_FALSE(test[test_idx] != ref[ref_idx]) ASSERT_FALSE(test[test_idx] != ref[ref_idx])
<< "Error in " << name << std::endl << "Error in " << name << std::endl
<< "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx
<< "," << ref_idx; << "," << ref_idx;
#endif
} }
} }
} }
...@@ -411,18 +425,33 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, ...@@ -411,18 +425,33 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
float atol = 0.0; float atol = 0.0;
float rtol = 0.0; float rtol = 0.0;
#ifdef __HIP_PLATFORM_AMD__
double atol_scale = 0.0;
double rtol_scale = 0.0;
if(itype == DType::kFloat32)
{
atol_scale = 1e-5;
}
#endif
if (rowwise) { if (rowwise) {
compareResults("output_c", output_c, ref_output.get(), true, atol, rtol); compareResults("output_c", output_c, ref_output.get(), true, atol, rtol);
compare_scaling_factors_one_dimensional_blocks("scale_inv", compare_scaling_factors_one_dimensional_blocks("scale_inv",
output_c.rowwise_cpu_scale_inv_ptr<float>(), output_c.rowwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv.get(), rows, blocks_x); ref_scale_inv.get(), rows, blocks_x
#ifdef __HIP_PLATFORM_AMD__
, atol_scale, rtol_scale
#endif
);
} }
if (colwise) { if (colwise) {
compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol); compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol);
compare_scaling_factors_one_dimensional_blocks("scale_inv_t", compare_scaling_factors_one_dimensional_blocks("scale_inv_t",
output_c.columnwise_cpu_scale_inv_ptr<float>(), output_c.columnwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv_t.get(), cols, blocks_x_t); ref_scale_inv_t.get(), cols, blocks_x_t
#ifdef __HIP_PLATFORM_AMD__
, atol_scale, rtol_scale
#endif
);
} }
} }
......
...@@ -13,7 +13,10 @@ ...@@ -13,7 +13,10 @@
#include <algorithm> #include <algorithm>
#include <cfloat> #include <cfloat>
#ifndef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
#include <condition_variable>
#include <type_traits>
#else
#include <cuda/barrier> #include <cuda/barrier>
#endif #endif
#include <utility> #include <utility>
...@@ -165,7 +168,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -165,7 +168,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}; };
extern __shared__ char smem_base[]; extern __shared__ char smem_base[];
#ifdef __HIP_PLATFORM_AMD__
using HipSMemVec = Vec<std::conditional_t<std::is_same_v<IType, float>, __hip_bfloat16, IType>, kNVecSMem>;
HipSMemVec* smem = reinterpret_cast<HipSMemVec*>(&smem_base[0]);
#else
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]); SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
#endif
// Step 1: Load input to shared memory // Step 1: Load input to shared memory
{ {
...@@ -199,7 +207,29 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -199,7 +207,29 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for (int i = 0; i < kNVecIn / kNVecSMem; ++i) { for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
int c = c_s + i; int c = c_s + i;
int r = r_s; int r = r_s;
#ifdef __HIP_PLATFORM_AMD__
if constexpr (std::is_same_v<IType, float>)
{
#pragma unroll
for(int j = 0; j < kNVecSMem; ++j)
{
uint32_t raw_val = *reinterpret_cast<const uint32_t*>(&input_vec.smem_type.data.elt[i].data.elt[j]);
// [Workaround] Under certain critical conditions, scale will be 2 * ref_scale because of float -> bfloat16.
// We use carry over here to avoid this issue.
if(pow_2_scaling && (raw_val & 0x0000FFFF))
{
raw_val |= 0x00010000;
}
smem[r * kSMemCol + c].data.elt[j] = *reinterpret_cast<const float*>(&raw_val);
}
}
else
{
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
}
#else
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i]; smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
#endif
} }
// Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case) // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
input_g += stride_g; input_g += stride_g;
...@@ -241,7 +271,22 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -241,7 +271,22 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i; int c = c_s + i;
int r = r_s; int r = r_s;
#ifdef __HIP_PLATFORM_AMD__
if constexpr (std::is_same_v<IType, float>)
{
#pragma unroll
for(int j = 0; j < kNVecSMem; ++j)
{
smem_vec[i].data.elt[j] = smem[r * kSMemCol + c].data.elt[j];
}
}
else
{
smem_vec[i] = smem[r * kSMemCol + c];
}
#else
smem_vec[i] = smem[r * kSMemCol + c]; smem_vec[i] = smem[r * kSMemCol + c];
#endif
} }
// Step 2.2: Compute local amax // Step 2.2: Compute local amax
CType amax = 0; CType amax = 0;
...@@ -257,7 +302,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -257,7 +302,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
#pragma unroll #pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
const float other_amax = __shfl_down(amax, delta, kThreadsPerWarp); const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else #else
const float other_amax = __shfl_down_sync(mask, amax, delta); const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif #endif
...@@ -266,7 +311,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -266,7 +311,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
amax = fmaxf(amax, other_amax); amax = fmaxf(amax, other_amax);
} }
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
amax = __shfl(amax, src_lane, kThreadsPerWarp); amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else #else
amax = __shfl_sync(mask, amax, src_lane); amax = __shfl_sync(mask, amax, src_lane);
#endif #endif
...@@ -340,7 +385,22 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -340,7 +385,22 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for (int i = 0; i < kNVecOut; ++i) { for (int i = 0; i < kNVecOut; ++i) {
int r = r_s + i; int r = r_s + i;
int c = c_s; int c = c_s;
#ifdef __HIP_PLATFORM_AMD__
if constexpr (std::is_same_v<IType, float>)
{
#pragma unroll
for(int j = 0; j < kNVecSMem; ++j)
{
smem_vec[i].data.elt[j] = smem[r * kSMemCol + c].data.elt[j];
}
}
else
{
smem_vec[i] = smem[r * kSMemCol + c];
}
#else
smem_vec[i] = smem[r * kSMemCol + c]; smem_vec[i] = smem[r * kSMemCol + c];
#endif
} }
#pragma unroll #pragma unroll
for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) { for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
...@@ -354,7 +414,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -354,7 +414,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
#pragma unroll #pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
const float other_amax = __shfl_down(amax, delta, kThreadsPerWarp); const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else #else
const float other_amax = __shfl_down_sync(mask, amax, delta); const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif #endif
...@@ -363,7 +423,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -363,7 +423,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
amax = fmaxf(amax, other_amax); amax = fmaxf(amax, other_amax);
} }
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
amax = __shfl(amax, src_lane, kThreadsPerWarp); amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else #else
amax = __shfl_sync(mask, amax, src_lane); amax = __shfl_sync(mask, amax, src_lane);
#endif #endif
...@@ -489,7 +549,12 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -489,7 +549,12 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
TRANSFORMER_ENGINE_SWITCH_CONDITION( TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned, full_tile, kAligned,
#ifdef __HIP_PLATFORM_AMD__
using HipSMemType = std::conditional_t<std::is_same_v<InputType, float>, hip_bfloat16, InputType>;
size_t smem_bytes = kSMemSize * sizeof(HipSMemType);
#else
size_t smem_bytes = kSMemSize * sizeof(InputType); size_t smem_bytes = kSMemSize * sizeof(InputType);
#endif
// shared memory must be requested up // shared memory must be requested up
if (smem_bytes >= 48 * 1024) { if (smem_bytes >= 48 * 1024) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment