Commit 531152d9 authored by Vijay Korthikanti

minor fixes

parent b1a83375
@@ -32,11 +32,9 @@ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
 template <>
 __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }
-template <>
-__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }
 template <>
 __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }
 template <>
 __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
@@ -250,6 +248,8 @@ __global__ void scaled_masked_softmax_warp_backward(
     // load data from global memory
     acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
     acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+    input_t temp_grad[ELEMENTS_PER_LDG_STG];
+    input_t temp_output[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : element_count;
...
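The two staging buffers added here (temp_grad, temp_output) let scaled_masked_softmax_warp_backward pull grad and output from global memory with one vectorized copy_vector call and then widen the __half values into the acc_t registers. The standalone sketch below illustrates only that load-and-widen pattern; the kernel name load_and_widen, the fixed ELEMENTS_PER_LDG_STG of 4, and the single-warp launch are assumptions for illustration, not code from this commit.

// staging_sketch.cu -- illustrative only; mirrors the temp_grad staging pattern.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) {
    *((float2*) dst) = *((float2*) src);   // one 8-byte load/store moves four halves
}

// Hypothetical stand-in for the backward kernel's load loop: stage a vectorized
// load of input_t (__half) in registers, then widen each element to acc_t (float).
__global__ void load_and_widen(const __half *grad, float *grad_reg_out, int element_count) {
    constexpr int ELEMENTS_PER_LDG_STG = 4;
    __half temp_grad[ELEMENTS_PER_LDG_STG];                  // staging buffer, as in the commit
    int element_index = threadIdx.x * ELEMENTS_PER_LDG_STG;
    if (element_index < element_count) {
        copy_vector<__half, ELEMENTS_PER_LDG_STG>(temp_grad, grad + element_index);
        #pragma unroll
        for (int e = 0; e < ELEMENTS_PER_LDG_STG; ++e) {
            grad_reg_out[element_index + e] = (float) temp_grad[e];   // widen to acc_t
        }
    }
}

int main() {
    const int n = 128;                                       // one warp * 4 elements per thread
    __half h_in[n];
    for (int i = 0; i < n; ++i) h_in[i] = __float2half(1.0f);
    __half *d_in;  float *d_out;
    cudaMalloc(&d_in, n * sizeof(__half));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(__half), cudaMemcpyHostToDevice);
    load_and_widen<<<1, 32>>>(d_in, d_out, n);
    float h_out[n];
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
    printf("grad_reg_out[0] = %f\n", h_out[0]);              // expect 1.000000
    cudaFree(d_in);  cudaFree(d_out);
    return 0;
}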
@@ -32,21 +32,22 @@ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
 template <>
 __device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }
-template <>
-__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }
 template <>
 __device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }
-template <>
-__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = 0; }
 template <>
 __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
 template <>
 __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }
+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_zero_vector(Datatype *dst);
+template <>
+__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
 int log2_ceil(int value) {
     int log2_value = 0;
     while ((1 << log2_value) < value) ++log2_value;
@@ -199,7 +200,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
             }
             copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
         } else if (element_index < element_count) {
-            copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE)
+            copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
         } else {
             break;
         }
...
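This hunk adds the missing primary template declaration for copy_zero_vector and fixes two problems visible in the old column: `*((float2*) dst) = 0;` does not compile because float2 cannot be assigned from an integer literal, so the specialization now uses make_float2(0.0f, 0.0f), and the copy_zero_vector call site in scaled_upper_triang_masked_softmax_warp_forward gains its missing semicolon. The standalone sketch below only demonstrates the specialization; the kernel name zero_masked_tail, the launch configuration, and the valid_count/element_count parameters are assumptions for illustration, not code from this commit.

// zero_vector_sketch.cu -- illustrative only; not the repository's forward kernel.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

template <>
__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) {
    // A single 8-byte float2 store writes four __half zeros at once;
    // make_float2(0.0f, 0.0f) is required since float2 has no assignment from 0.
    *((float2*) dst) = make_float2(0.0f, 0.0f);
}

// Hypothetical kernel: each thread zeros one group of four halves past valid_count,
// loosely mimicking how the forward kernel writes zeros for masked-out columns.
__global__ void zero_masked_tail(__half *dst, int valid_count, int element_count) {
    int element_index = threadIdx.x * 4;
    if (element_index >= valid_count && element_index < element_count) {
        copy_zero_vector<__half, 4>(dst + element_index);
    }
}

int main() {
    const int n = 128;
    __half h[n];
    for (int i = 0; i < n; ++i) h[i] = __float2half(1.0f);
    __half *d;
    cudaMalloc(&d, n * sizeof(__half));
    cudaMemcpy(d, h, n * sizeof(__half), cudaMemcpyHostToDevice);
    zero_masked_tail<<<1, 32>>>(d, /*valid_count=*/64, /*element_count=*/n);
    cudaMemcpy(h, d, n * sizeof(__half), cudaMemcpyDeviceToHost);
    printf("h[63] = %f, h[64] = %f\n", __half2float(h[63]), __half2float(h[64]));  // 1.000000, 0.000000
    cudaFree(d);
    return 0;
}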