Commit c5167eb7 authored by zhangyue's avatar zhangyue
Browse files

issue/390: kunlun p800 causal softmax

parent c35920e2
...@@ -114,6 +114,29 @@ inline __device__ bfloat16_t atomicAdd<bfloat16_t>(__shared_ptr__ bfloat16_t *pt ...@@ -114,6 +114,29 @@ inline __device__ bfloat16_t atomicAdd<bfloat16_t>(__shared_ptr__ bfloat16_t *pt
return old; return old;
} }
/**
 * @brief atomicMax for kunlun xpu — lock-emulated: all callers serialize
 *        through a cluster-wide ticket lock rather than a hardware atomic.
 * @param ptr: pointer to shared memory holding the running maximum
 * @param value: value to compare against *ptr
 * @return the value previously stored at ptr (before this update),
 *         matching CUDA atomicMax semantics
 */
template <typename T>
inline __device__ T atomicMax(__shared_ptr__ T *ptr, T value) {
    ticket_lock_mix();
    T old = loadsm(ptr);
    if constexpr (std::is_same<T, bfloat16_t>::value) {
        // bfloat16: widen to float, compare there, then narrow back
        // (avoids relying on an fmax overload for bfloat16_t)
        float of = __bfloat162float(old);
        float vf = __bfloat162float(value);
        float maxf = fmax(of, vf);
        bfloat16_t max = __float2bfloat16_rn(maxf);
        *ptr = max;
    } else {
        *ptr = fmax(old, value);
    }
    // flush the shared-memory store before releasing the lock
    mfence_sm();
    ticket_unlock_mix();
    return old;
}
/** /**
* @brief Get index of broadcasted input * @brief Get index of broadcasted input
* flat_index: flatten index of output tensor * flat_index: flatten index of output tensor
...@@ -156,6 +179,75 @@ inline __device__ int indexToOffset( ...@@ -156,6 +179,75 @@ inline __device__ int indexToOffset(
return res; return res;
} }
/**
 * @brief Get max of an array in local memory (scalar fallback for any T).
 * @param data_ptr: pointer to local memory
 * @param len: length of array; must be >= 1 (data_ptr[0] seeds the result)
 * @return max value
 */
template <typename T>
__inline__ __device__ T max(const T *data_ptr, size_t len) {
    // Seed with element 0, then scan the rest; starting at 1 avoids the
    // original's redundant fmax(data_ptr[0], data_ptr[0]) comparison.
    T max_val = data_ptr[0];
    for (size_t i = 1; i < len; ++i) {
        max_val = fmax(max_val, data_ptr[i]);
    }
    return max_val;
}
// Use simd vector instruction to calculate max of a half array.
// Handles the scalar tail (len % 32 elements) first, then reduces the
// full 32-lane float16 vectors. Precondition: len >= 1 (data_ptr[0] is
// read unconditionally).
template <>
__inline__ __device__ half max(const half *data_ptr, size_t len) {
    int remain = len % 32;          // elements that don't fill a 32-lane vector
    int offset_last = len - remain; // start of the scalar tail
    half res = data_ptr[0];
    // scalar max over the tail [offset_last, len)
    for (int i = offset_last; i < len; i++) {
        res = fmax(res, *(data_ptr + i));
    }
    mfence();
    if (offset_last != 0) {
        __local__ half acc_buf[32];
        float16x32_t v_mv = vload_lm_float16x32_mz(data_ptr);
        // fold in every remaining full 32-half vector
        for (int i = 32; i < offset_last; i += 32) {
            float16x32_t v_0 = vload_lm_float16x32_mz(data_ptr + i);
            v_mv = vvmax_float16x32_mz(v_mv, v_0);
        }
        vstore_lm_float16x32_mz(acc_buf, v_mv);
        mfence();
        // horizontal max over the 32 accumulated lanes
        for (int i = 0; i < 32; i++) {
            res = fmax(res, acc_buf[i]);
        }
    }
    return res;
}
// Use simd vector instruction to calculate max of a float array.
// Handles the scalar tail (len % 16 elements) first, then reduces the
// full 16-lane float vectors. Precondition: len >= 1 (data_ptr[0] is
// read unconditionally).
template <>
__inline__ __device__ float max(const float *data_ptr, size_t len) {
    int remain = len % 16;          // elements that don't fill a 16-lane vector
    int offset_last = len - remain; // start of the scalar tail
    float res = data_ptr[0];
    // scalar max over the tail [offset_last, len)
    for (int i = offset_last; i < len; i++) {
        res = fmax(res, *(data_ptr + i));
    }
    mfence();
    if (offset_last != 0) {
        __local__ float acc_buf[16];
        float32x16_t v_mv = vload_lm_float32x16_mz(data_ptr);
        // fold in every remaining full 16-float vector
        for (int i = 16; i < offset_last; i += 16) {
            float32x16_t v_0 = vload_lm_float32x16_mz(data_ptr + i);
            v_mv = vvmax_float32x16_mz(v_mv, v_0);
        }
        vstore_lm_float32x16_mz(acc_buf, v_mv);
        mfence();
        // horizontal max over the 16 accumulated lanes
        for (int i = 0; i < 16; i++) {
            res = fmax(res, acc_buf[i]);
        }
    }
    return res;
}
} // namespace device::kunlun::kernel } // namespace device::kunlun::kernel
#endif // __INFINIOP_KUNLUN_KERNEL_COMMON_H__ #endif // __INFINIOP_KUNLUN_KERNEL_COMMON_H__
#ifndef __CAUSAL_SOFTMAX_KUNLUN_H__
#define __CAUSAL_SOFTMAX_KUNLUN_H__
#include "../causal_softmax.h"
DESCRIPTOR(kunlun)
#endif
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "causal_softmax_kunlun.h"
#include "kernel.h"
/**
 * Causal-softmax kernel: each cluster processes one row of the
 * (height x width) score matrix; the caller launches with the row count
 * as the first launch argument.
 *
 * @param y          output base pointer (global memory)
 * @param x          input base pointer (global memory)
 * @param batch      NOTE(review): currently unused inside the kernel
 * @param height     number of rows (seq_len at the call site)
 * @param width      row length (total_seq_len at the call site)
 * @param y_stride_h element stride between consecutive output rows
 * @param x_stride_h element stride between consecutive input rows
 *
 * Assumes width * sizeof(Tdata) <= SM_SIZE so one full row fits in each
 * shared-memory buffer — TODO confirm the caller guarantees this.
 */
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__global__ void causalSoftmaxKernel(
    Tdata *y,
    const Tdata *x,
    uint32_t batch,
    uint32_t height,
    uint32_t width,
    int32_t y_stride_h,
    int32_t x_stride_h) {
    // One shared-memory staging buffer each for the input and output row.
    __shared__ Tdata x_sm[SM_SIZE / sizeof(Tdata)];
    __shared__ Tdata y_sm[SM_SIZE / sizeof(Tdata)];
    int row_id = cluster_id(); // each cluster owns one row
    __global_ptr__ Tdata *y_ = y + row_id * y_stride_h;
    __global_ptr__ const Tdata *x_ = x + row_id * x_stride_h;
    // Core 0 stages the row into shared memory; sync_cluster() is relied
    // on to publish it to all cores before the softmax runs — presumably
    // it also completes the async copy (runtime semantics; verify).
    if (core_id() == 0) {
        GM2SM_ASYNC(x_, x_sm, width * sizeof(Tdata));
    }
    sync_cluster();
    causalSoftmaxBlock<BLOCK_SIZE, Tdata, Tcompute>(y_sm, x_sm, height, width, row_id);
    // Core 0 writes the finished row back to global memory.
    if (core_id() == 0) {
        SM2GM_ASYNC(y_sm, y_, width * sizeof(Tdata));
    }
    sync_cluster();
}
namespace op::causal_softmax::kunlun {
struct Descriptor::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
};
// Releases the backend-specific state owned by this descriptor.
Descriptor::~Descriptor() {
    delete _opaque;
}
/**
 * Build a causal-softmax descriptor for the Kunlun backend.
 *
 * Validates the tensor descriptors via CausalSoftmaxInfo::create and, on
 * success, allocates a Descriptor that captures the handle's internal
 * device state. Workspace size is 0 for this backend.
 *
 * @param handle   device handle (must be a Kunlun handle)
 * @param desc_ptr receives the newly allocated descriptor
 * @param y_desc   output tensor descriptor
 * @param x_desc   input tensor descriptor
 * @return INFINI_STATUS_SUCCESS, or the status produced by info validation
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto result = CausalSoftmaxInfo::create(y_desc, x_desc);
    CHECK_RESULT(result);
    auto kunlun_handle = reinterpret_cast<device::kunlun::Handle *>(handle);
    *desc_ptr = new Descriptor(
        new Opaque{kunlun_handle->internal()},
        result.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
/**
 * Launch one causal-softmax kernel invocation per batch element.
 *
 * @param y / x           device pointers to output / input tensors
 * @param dtype           element dtype (F16, BF16, F32 supported)
 * @param batch_size      number of batches (one launch each)
 * @param seq_len         rows per batch (also the cluster count per launch)
 * @param total_seq_len   row length
 * @param *_stride_b/_h   element strides between batches / rows
 * @param stream          Kunlun stream the kernels are enqueued on
 * @return INFINI_STATUS_SUCCESS or INFINI_STATUS_BAD_TENSOR_DTYPE
 */
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
                            size_t batch_size, size_t seq_len, size_t total_seq_len,
                            ptrdiff_t y_stride_b, ptrdiff_t y_stride_h,
                            ptrdiff_t x_stride_b, ptrdiff_t x_stride_h,
                            kunlunStream_t stream) {
    // Kunlunxin kernels don't support ptrdiff_t and size_t as parameters,
    // so narrow everything to fixed-width 32-bit types first.
    uint32_t batch_size_ = static_cast<uint32_t>(batch_size);
    uint32_t seq_len_ = static_cast<uint32_t>(seq_len);
    uint32_t total_seq_len_ = static_cast<uint32_t>(total_seq_len);
    int32_t y_stride_b_ = static_cast<int32_t>(y_stride_b);
    int32_t y_stride_h_ = static_cast<int32_t>(y_stride_h);
    int32_t x_stride_b_ = static_cast<int32_t>(x_stride_b);
    int32_t x_stride_h_ = static_cast<int32_t>(x_stride_h);
// BUG FIX: the original passed the un-narrowed size_t/ptrdiff_t values
// (batch_size, seq_len, ..., y_stride_h, x_stride_h) into the kernel,
// defeating the 32-bit casts above. Pass the narrowed variables instead.
#define LAUNCH_KERNEL(Tdata, Tcompute)                                                                           \
    for (uint32_t i = 0; i < batch_size_; ++i) {                                                                 \
        causalSoftmaxKernel<BLOCK_SIZE, Tdata, Tcompute>                                                         \
            <<<seq_len_, BLOCK_SIZE, stream>>>((Tdata *)y + i * y_stride_b_, (const Tdata *)x + i * x_stride_b_, \
                                               batch_size_, seq_len_, total_seq_len_,                            \
                                               y_stride_h_, x_stride_h_);                                        \
    }
    if (dtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(half, float);
    } else if (dtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(bfloat16_t, float);
    } else if (dtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(float, float);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
#undef LAUNCH_KERNEL
    return INFINI_STATUS_SUCCESS;
}
/**
 * Run causal softmax on the Kunlun device described by this descriptor.
 *
 * @param workspace      unused by this backend (workspace size is 0)
 * @param workspace_size unused by this backend
 * @param y              output device pointer
 * @param x              input device pointer
 * @param stream_        opaque stream handle, cast to kunlunStream_t
 * @return INFINI_STATUS_SUCCESS or the status propagated from launchKernel
 */
infiniStatus_t Descriptor::calculate(void *workspace,
                                     size_t workspace_size,
                                     void *y,
                                     const void *x,
                                     void *stream_) const {
    auto stream = static_cast<kunlunStream_t>(stream_);
    // Fixed launch width of 64 cores per cluster; shapes and strides come
    // from the info captured at descriptor creation.
    CHECK_STATUS(launchKernel<64>(
        y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
        _info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::causal_softmax::kunlun
#ifndef __CAUSAL_SOFTMAX_KUNLUN_KERNEL_H__
#define __CAUSAL_SOFTMAX_KUNLUN_KERNEL_H__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"
using namespace device::kunlun::kernel;
/**
 * Compute causal softmax over one row held in shared memory.
 *
 * @param y       shared-memory output row (length width)
 * @param x       shared-memory input row (length width)
 * @param height  number of rows in the score matrix
 * @param width   row length
 * @param row_id  index of the row this cluster is processing
 *
 * Row row_id has (width - height + 1 + row_id) unmasked leading elements;
 * everything past the causal boundary is written as zero. Three phases
 * separated by sync_cluster(): max-reduce, exp(x - max) with masking,
 * then normalization by the row sum.
 */
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__device__ void causalSoftmaxBlock(
    __shared_ptr__ Tdata *y,
    __shared_ptr__ const Tdata *x,
    size_t height,
    size_t width,
    int row_id) {
    // Reduce max over the unmasked prefix of the row and broadcast it
    // through shared memory
    __shared__ Tdata max_;
    Tdata max_0 = op::common_kunlun::reduce_op::max<BLOCK_SIZE, Tdata>(x, width - height + 1 + size_t(row_id));
    if (core_id() == 0) {
        max_ = max_0;
    }
    sync_cluster();
    // Elementwise: subtract max, exponentiate, and zero the masked tail;
    // cores stride the row by BLOCK_SIZE
    for (size_t col = core_id(); col < width; col += BLOCK_SIZE) {
        // row_id ↓ |<-     width      ->|
        //   0       | * * * ... *       |
        //   1       | * * * ... * *     |
        //   2       | * * * ... * * *   |
        //  height: 3                col_id->
        // col is unmasked iff col <= width - height + row_id
        if (width + size_t(row_id) >= col + height) {
            if constexpr (std::is_same_v<Tdata, half>) {
                y[col] = hexp(loadsm(x + col) - loadsm(&max_));
            } else if constexpr (std::is_same_v<Tdata, bfloat16_t>) {
                // bfloat16 has no native exp path here: compute in float
                y[col] = __float2bfloat16(exp(__bfloat162float(x[col]) - __bfloat162float(max_)));
            } else {
                y[col] = exp(x[col] - max_);
            }
        } else {
            y[col] = Tdata(0);
        }
    }
    sync_cluster();
    // Reduce sum of the exponentials for this row (in Tcompute precision)
    __shared__ Tcompute sum_;
    Tcompute sum_0 = op::common_kunlun::reduce_op::sum<BLOCK_SIZE, Tdata, Tcompute>(y, width);
    if (core_id() == 0) {
        sum_ = sum_0;
    }
    sync_cluster();
    // Normalize; a zero sum (fully masked / underflowed row) yields zeros
    // instead of dividing by zero
    for (size_t col = core_id(); col < width; col += BLOCK_SIZE) {
        if (sum_ != 0) {
            y[col] = to<Tdata>(to<Tcompute>(loadsm(y + col)) / sum_);
        } else {
            y[col] = Tdata(0);
        }
    }
    sync_cluster();
}
#endif
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#ifdef ENABLE_CAMBRICON_API #ifdef ENABLE_CAMBRICON_API
#include "bang/causal_softmax_bang.h" #include "bang/causal_softmax_bang.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/causal_softmax_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor( __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -50,6 +53,9 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor( ...@@ -50,6 +53,9 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
#endif #endif
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend) CREATE(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun)
#endif #endif
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -80,6 +86,9 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe ...@@ -80,6 +86,9 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
#endif #endif
#ifdef ENABLE_CAMBRICON_API #ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang) GET(INFINI_DEVICE_CAMBRICON, bang)
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun)
#endif #endif
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -115,6 +124,9 @@ __C infiniStatus_t infiniopCausalSoftmax( ...@@ -115,6 +124,9 @@ __C infiniStatus_t infiniopCausalSoftmax(
#endif #endif
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
CALCULATE(INFINI_DEVICE_ASCEND, ascend) CALCULATE(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun)
#endif #endif
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -145,6 +157,9 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD ...@@ -145,6 +157,9 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
#endif #endif
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
DESTROY(INFINI_DEVICE_ASCEND, ascend) DESTROY(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_KUNLUN_API
DESTROY(INFINI_DEVICE_KUNLUN, kunlun)
#endif #endif
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
...@@ -5,14 +5,4 @@ ...@@ -5,14 +5,4 @@
DESCRIPTOR(kunlun) DESCRIPTOR(kunlun)
#define INSTANTIATE_RMSNORM_KERNEL(BLOCK_SIZE, Tcompute, Tdata, Tweight) \
template __global__ void rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight>( \
Tdata * y, \
int32_t stride_y, \
const Tdata *x, \
int32_t stride_x, \
const Tweight *w, \
uint32_t dim, \
float epsilon);
#endif #endif
...@@ -26,7 +26,7 @@ __device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size ...@@ -26,7 +26,7 @@ __device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size
atomicAdd(&temp_storage, ss); atomicAdd(&temp_storage, ss);
sync_cluster(); sync_cluster();
return temp_storage; return loadsm(&temp_storage);
} }
// Sum(x) on contiguous data of length count // Sum(x) on contiguous data of length count
...@@ -48,7 +48,30 @@ __device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t coun ...@@ -48,7 +48,30 @@ __device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t coun
atomicAdd(&temp_storage, ss); atomicAdd(&temp_storage, ss);
sync_cluster(); sync_cluster();
return temp_storage; return loadsm(&temp_storage);
}
// Max(x) on contiguous shared-memory data of length count.
// All BLOCK_SIZE cores cooperate: each core reduces a strided slice,
// then the partials are merged via the lock-emulated atomicMax on a
// shared accumulator. Precondition: count >= 1 (data_ptr[0] seeds both
// the per-core partial and the shared accumulator).
template <unsigned int BLOCK_SIZE, typename Tdata>
__device__ inline Tdata max(__shared_ptr__ const Tdata *data_ptr, size_t count) {
    // Per-core partial max over a BLOCK_SIZE-strided slice of the data.
    // (Removed a leftover commented-out duplicate load and the redundant
    // to<Tdata>() identity conversion from the original.)
    Tdata max_val = loadsm(data_ptr);
    for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
        max_val = fmax(max_val, loadsm(data_ptr + i));
    }
    __shared__ Tdata temp_storage;
    if (core_id() == 0) {
        // seed the accumulator with a real element before merging
        temp_storage = loadsm(data_ptr);
    }
    sync_cluster();
    atomicMax(&temp_storage, max_val);
    sync_cluster();
    return loadsm(&temp_storage);
}
} // namespace op::common_kunlun::reduce_op } // namespace op::common_kunlun::reduce_op
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment