"vscode:/vscode.git/clone" did not exist on "fccb6d9d96483833558c395e41532941bbab4a9c"
Commit 0b0a70a5 authored by yuguo's avatar yuguo
Browse files
parents e80f260d 3ce226ae
...@@ -171,7 +171,8 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], ...@@ -171,7 +171,8 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
for (unsigned int j = 0; j < nvec_in; ++j) { for (unsigned int j = 0; j < nvec_in; ++j) {
CType elt = step_dbias.data.elt[j]; CType elt = step_dbias.data.elt[j];
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
elt = __shfl(elt, dbias_shfl_src_lane); // shuffle data in a warp elt = __shfl(elt, dbias_shfl_src_lane, THREADS_PER_WARP); // shuffle data in a warp
__syncthreads();
#else #else
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
#endif #endif
......
...@@ -91,16 +91,25 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( ...@@ -91,16 +91,25 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
local_output_c.store_to(&output_c[row * row_length + col]); local_output_c.store_to(&output_c[row * row_length + col]);
} }
} }
#ifndef __HIP_PLATFORM_AMD__
// Copy from registers to shared memory to global memory // Copy from registers to shared memory to global memory
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1]; __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#else
constexpr size_t inner_dim = THREADS_PER_WARP + 1;
constexpr size_t outter_dim = THREADS_PER_WARP;
__shared__ OVecT shared_output_t[outter_dim * inner_dim];
#endif
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy; const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx; const size_t j1 = tidx;
#ifndef __HIP_PLATFORM_AMD__
shared_output_t[j1][i1] = local_output_t[j2][iter]; shared_output_t[j1][i1] = local_output_t[j2][iter];
#else
shared_output_t[j1 * inner_dim + i1] = local_output_t[j2][iter];
#endif
} }
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
...@@ -109,7 +118,11 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( ...@@ -109,7 +118,11 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
const size_t j1 = tidy + iter * bdimy; const size_t j1 = tidy + iter * bdimy;
const size_t row = tile_row + i1 * nvec_out; const size_t row = tile_row + i1 * nvec_out;
const size_t col = tile_col + j1 * nvec_in + j2; const size_t col = tile_col + j1 * nvec_in + j2;
#ifndef __HIP_PLATFORM_AMD__
shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]); shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]);
#else
shared_output_t[j1 * inner_dim + i1].store_to(&output_t[col * num_rows + row]);
#endif
} }
__syncthreads(); __syncthreads();
} }
......
...@@ -91,7 +91,8 @@ inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_O ...@@ -91,7 +91,8 @@ inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_O
for (unsigned int j = 0; j < NVEC_IN; ++j) { for (unsigned int j = 0; j < NVEC_IN; ++j) {
CType elt = step_dbias.data.elt[j]; CType elt = step_dbias.data.elt[j];
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
elt = __shfl(elt, dbias_shfl_src_lane); // shuffle data in a warp elt = __shfl(elt, dbias_shfl_src_lane, THREADS_PER_WARP); // shuffle data in a warp
__syncthreads();
#else #else
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment