Unverified commit 70bd26e8, authored by Kirthi Shankar Sivamani and committed by GitHub
Browse files

Fused rope compute in fp32 (#645)



Fused rope computation in fp32
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8641ab77
...@@ -28,12 +28,12 @@ __device__ void fused_rope_block_forward( ...@@ -28,12 +28,12 @@ __device__ void fused_rope_block_forward(
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
scalar_t v_src = src[offset_src]; float v_src = src[offset_src];
scalar_t v_src_rotate = (d_id + d2 / 2 < d2) float v_src_rotate = (d_id + d2 / 2 < d2)
? -src[offset_src + (d2 / 2) * stride_d] ? -static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
: src[offset_src + (d2 / 2 - d2) * stride_d]; : static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
dst[offset_dst] = dst[offset_dst] =
v_src * (scalar_t)v_cos + v_src_rotate * (scalar_t)v_sin; v_src * v_cos + v_src_rotate * v_sin;
} }
} }
...@@ -61,16 +61,16 @@ __device__ void fused_rope_block_backward( ...@@ -61,16 +61,16 @@ __device__ void fused_rope_block_backward(
int s_id = blockIdx.x; int s_id = blockIdx.x;
#pragma unroll #pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
scalar_t v_cos = cosf(freqs[s_id * d2 + d_id]); float v_cos = cosf(freqs[s_id * d2 + d_id]);
scalar_t v_sin = (d_id + d2 / 2 < d2) float v_sin = (d_id + d2 / 2 < d2)
? sinf(freqs[s_id * d2 + d_id + d2 / 2]) ? sinf(freqs[s_id * d2 + d_id + d2 / 2])
: -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
#pragma unroll #pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
scalar_t v_src = src[offset_src]; float v_src = src[offset_src];
scalar_t v_src_rotate = (d_id + d2 / 2 < d2) float v_src_rotate = (d_id + d2 / 2 < d2)
? src[offset_src + (d2 / 2) * stride_d] ? src[offset_src + (d2 / 2) * stride_d]
: src[offset_src + (d2 / 2 - d2) * stride_d]; : src[offset_src + (d2 / 2 - d2) * stride_d];
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment