"vscode:/vscode.git/clone" did not exist on "080438477f319149db4f09f3a8835dde23609f7a"
Commit 0b0a70a5 authored by yuguo
Browse files
parents e80f260d 3ce226ae
......@@ -171,7 +171,8 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
for (unsigned int j = 0; j < nvec_in; ++j) {
CType elt = step_dbias.data.elt[j];
#ifdef __HIP_PLATFORM_AMD__
elt = __shfl(elt, dbias_shfl_src_lane); // shuffle data in a warp
elt = __shfl(elt, dbias_shfl_src_lane, THREADS_PER_WARP); // shuffle data in a warp
__syncthreads();
#else
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
#endif
......
......@@ -91,16 +91,25 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
local_output_c.store_to(&output_c[row * row_length + col]);
}
}
#ifndef __HIP_PLATFORM_AMD__
// Copy from registers to shared memory to global memory
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#else
constexpr size_t inner_dim = THREADS_PER_WARP + 1;
constexpr size_t outter_dim = THREADS_PER_WARP;
__shared__ OVecT shared_output_t[outter_dim * inner_dim];
#endif
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
#ifndef __HIP_PLATFORM_AMD__
shared_output_t[j1][i1] = local_output_t[j2][iter];
#else
shared_output_t[j1 * inner_dim + i1] = local_output_t[j2][iter];
#endif
}
__syncthreads();
#pragma unroll
......@@ -109,7 +118,11 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
const size_t j1 = tidy + iter * bdimy;
const size_t row = tile_row + i1 * nvec_out;
const size_t col = tile_col + j1 * nvec_in + j2;
#ifndef __HIP_PLATFORM_AMD__
shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]);
#else
shared_output_t[j1 * inner_dim + i1].store_to(&output_t[col * num_rows + row]);
#endif
}
__syncthreads();
}
......
......@@ -91,7 +91,8 @@ inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_O
for (unsigned int j = 0; j < NVEC_IN; ++j) {
CType elt = step_dbias.data.elt[j];
#ifdef __HIP_PLATFORM_AMD__
elt = __shfl(elt, dbias_shfl_src_lane); // shuffle data in a warp
elt = __shfl(elt, dbias_shfl_src_lane, THREADS_PER_WARP); // shuffle data in a warp
__syncthreads();
#else
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
#endif
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment