Unverified Commit c24a52ea authored by thatPepe, committed by GitHub

issue/278 - Cambricon Random Sample

parent f19238db
...@@ -7,6 +7,21 @@
namespace device::bang::kernel {
// Scalar conversion helpers used by the device-side kernels.
template <typename T>
__mlu_device__ float to_float(const T &v) {
    return static_cast<float>(v);
}
template <typename T>
__mlu_device__ bfloat16_t to_bfloat16(const T &v) {
    return static_cast<bfloat16_t>(v);
}
template <typename T>
__mlu_device__ half to_half(const T &v) {
    return static_cast<half>(v);
}
/**
 * @brief Converts a flattened index to a reduced offset considering broadcasting.
 *
...
#ifndef __RANDOM_SAMPLE_BANG_H__
#define __RANDOM_SAMPLE_BANG_H__
#include "../random_sample.h"
DESCRIPTOR(bang)
#endif // __RANDOM_SAMPLE_BANG_H__
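
For orientation: DESCRIPTOR(bang) is a macro from ../random_sample.h (not part of this diff) that declares the op::random_sample::bang::Descriptor class implemented in the next file. A rough, hypothetical sketch of the declaration it produces, inferred only from the members the implementation below relies on and likely differing from the real macro:

// Hypothetical approximation of what DESCRIPTOR(bang) declares; the real
// definition lives in ../random_sample.h.
namespace op::random_sample::bang {
class Descriptor {
    struct Opaque;
    Opaque *_opaque;
    RandomSampleInfo _info;
    size_t _min_workspace_size;

public:
    ~Descriptor();
    size_t minWorkspaceSize() const;
    static infiniStatus_t create(
        infiniopHandle_t handle,
        Descriptor **desc_ptr,
        infiniopTensorDescriptor_t result_desc,
        infiniopTensorDescriptor_t probs_desc);
    infiniStatus_t calculate(
        void *workspace, size_t workspace_size,
        void *result, const void *probs,
        float random_val, float topp, int topk, float temperature,
        void *stream) const;
};
} // namespace op::random_sample::bang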
#include "../../../devices/bang/bang_handle.h"
#include "../info.h"
#include "random_sample_bang.h"
#include "random_sample_kernel.mlu"
namespace op::random_sample::bang {
struct Descriptor::Opaque {
    std::shared_ptr<device::bang::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t result_desc,
    infiniopTensorDescriptor_t probs_desc) {
    auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
    auto result = RandomSampleInfo::create(result_desc, probs_desc);
    CHECK_RESULT(result);
    auto info = result.take();

    size_t workspace_size;

#define CASE_P(CASE, Tidx, Tval)                                            \
    case CASE: {                                                            \
        auto workspace_result = calculateWorkspace<Tidx, Tval>(info.n);     \
        CHECK_RESULT(workspace_result);                                     \
        workspace_size = workspace_result.take();                           \
    } break

#define CASE_I(CASE, Tidx)                                \
    case CASE:                                            \
        switch (info.dt_p) {                              \
            CASE_P(INFINI_DTYPE_F16, Tidx, half);         \
            CASE_P(INFINI_DTYPE_BF16, Tidx, bfloat16_t);  \
            CASE_P(INFINI_DTYPE_F32, Tidx, float);        \
        default:                                          \
            return INFINI_STATUS_BAD_TENSOR_DTYPE;        \
        }                                                 \
        break

    switch (info.dt_i) {
        CASE_I(INFINI_DTYPE_I8, int8_t);
        CASE_I(INFINI_DTYPE_I16, int16_t);
        CASE_I(INFINI_DTYPE_I32, int32_t);
        CASE_I(INFINI_DTYPE_I64, int64_t);
        CASE_I(INFINI_DTYPE_U8, uint8_t);
        CASE_I(INFINI_DTYPE_U16, uint16_t);
        CASE_I(INFINI_DTYPE_U32, uint32_t);
        CASE_I(INFINI_DTYPE_U64, uint64_t);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

#undef CASE_I
#undef CASE_P

    *desc_ptr = new Descriptor(
        info,
        workspace_size,
        new Opaque{handle->internal()},
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
size_t Descriptor::minWorkspaceSize() const {
    return _min_workspace_size;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *result,
    const void *probs,
    float random_val,
    float topp,
    int topk,
    float temperature,
    void *stream) const {
    if (workspace_size < _min_workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    Calculate::calculate<Algo>(
        Algo{}, _info, workspace, workspace_size,
        result, probs,
        random_val, topp, topk, temperature,
        stream);
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::random_sample::bang
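
The CASE_I/CASE_P macros above implement a two-level dtype dispatch: the outer switch selects the index type of the result tensor, the inner switch selects the value type of the probability tensor, and each leaf instantiates calculateWorkspace for that type pair. As a hedged illustration, a hand-written expansion of a single branch (int32 indices, float32 probabilities) would look roughly like this; in the patch this code is generated by the preprocessor, not written out:

// Hand-expanded illustration of CASE_I(INFINI_DTYPE_I32, int32_t) for an
// F32 probability tensor; other dtypes follow the same pattern.
switch (info.dt_i) {
case INFINI_DTYPE_I32:
    switch (info.dt_p) {
    case INFINI_DTYPE_F32: {
        auto workspace_result = calculateWorkspace<int32_t, float>(info.n);
        CHECK_RESULT(workspace_result);
        workspace_size = workspace_result.take();
    } break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    break;
default:
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}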
This diff is collapsed.
...@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/random_sample_nvidia.cuh"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/random_sample_bang.h"
#endif
#ifdef ENABLE_METAX_API
#include "metax/random_sample_metax.h"
#endif
...@@ -41,6 +44,9 @@ infiniopCreateRandomSampleDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
...@@ -77,6 +83,9 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
...@@ -123,6 +132,9 @@ __C infiniStatus_t infiniopRandomSample(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
...@@ -156,6 +168,9 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor(
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
...
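
The dispatcher changes above register the new backend in four places: descriptor creation, workspace-size query, the sampling call, and descriptor destruction, each guarded by ENABLE_CAMBRICON_API. The CREATE/GET/CALCULATE/DELETE macros themselves are defined earlier in operator.cc and are not shown in this diff; as a speculative sketch, based only on the bang Descriptor::create signature added above, CREATE(INFINI_DEVICE_CAMBRICON, bang) plausibly expands to something like:

// Speculative expansion of CREATE(INFINI_DEVICE_CAMBRICON, bang); the real
// macro is defined in operator.cc and may differ.
case INFINI_DEVICE_CAMBRICON:
    return op::random_sample::bang::Descriptor::create(
        handle,
        reinterpret_cast<op::random_sample::bang::Descriptor **>(desc_ptr),
        result_desc,
        probs_desc);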
...@@ -56,7 +56,14 @@ def random_sample(data, random_val, topp, topk, voc, temperature):
    sorted_vals, sorted_indices = torch.sort(data, descending=True)
    scaled_vals = (sorted_vals - sorted_vals[0]) / temperature
    try:
        probs = torch.softmax(scaled_vals, dim=0)
    except RuntimeError as e:
        if "not implemented for 'Half'" in str(e):
            scaled_vals = scaled_vals.to(torch.float32)
            probs = torch.softmax(scaled_vals, dim=0)
        else:
            raise
    cum_probs = torch.cumsum(probs, dim=0)
    k_index = min(topk, voc) - 1
...
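
The try/except added to the PyTorch reference keeps the float16 test path working on builds where CPU softmax is not implemented for Half. A self-contained sketch of the same fallback pattern (the helper name is illustrative, not taken from the test file):

import torch

def softmax_with_half_fallback(x: torch.Tensor) -> torch.Tensor:
    # Try the native dtype first; fall back to float32 only when the backend
    # reports that the Half kernel is missing.
    try:
        return torch.softmax(x, dim=0)
    except RuntimeError as e:
        if "not implemented for 'Half'" in str(e):
            return torch.softmax(x.to(torch.float32), dim=0)
        raise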