Commit d741ee7d authored by xgqdut2016
Browse files

issue/342: success random_sample all

parent c5bc6628
...@@ -162,7 +162,7 @@ __device__ void findTopOne_local( ...@@ -162,7 +162,7 @@ __device__ void findTopOne_local(
result[0] = indices_a; result[0] = indices_a;
} }
template <unsigned int BLOCK_SIZE, typename Tval, typename Tcompute, typename Tidx> template <unsigned int CLUSTER_SIZE, unsigned int BLOCK_SIZE, typename Tval, typename Tcompute, typename Tidx>
__global__ void random_sampleKernel(Tidx *result, __global__ void random_sampleKernel(Tidx *result,
const Tval *probs, const Tval *probs,
float random_val, float random_val,
...@@ -263,7 +263,7 @@ __global__ void random_sampleKernel(Tidx *result, ...@@ -263,7 +263,7 @@ __global__ void random_sampleKernel(Tidx *result,
findTopk(values_global, indices_global, nthreads * topk, topk); findTopk(values_global, indices_global, nthreads * topk, topk);
} }
} }
sync_cluster();
//上面这部分是计算topk,数据分别存储在values_global,indices_global里面 //上面这部分是计算topk,数据分别存储在values_global,indices_global里面
__global_ptr__ Tval *values_global_ = values_global; __global_ptr__ Tval *values_global_ = values_global;
__shared__ Tval max_value; __shared__ Tval max_value;
...@@ -290,7 +290,8 @@ __global__ void random_sampleKernel(Tidx *result, ...@@ -290,7 +290,8 @@ __global__ void random_sampleKernel(Tidx *result,
if(cid == 0){ if(cid == 0){
if constexpr (std::is_same_v<Tcompute, half>) { if constexpr (std::is_same_v<Tcompute, half>) {
sum_ = __float2half(0.0f); sum_ = __float2half(0.0f);
} else if constexpr (std::is_same_v<Tcompute, bfloat16_t>) { }
else if constexpr (std::is_same_v<Tcompute, bfloat16_t>) {
sum_ = __float2bfloat16(0.0f); sum_ = __float2bfloat16(0.0f);
} }
else if constexpr (std::is_same_v<Tcompute, float>) { else if constexpr (std::is_same_v<Tcompute, float>) {
...@@ -302,14 +303,15 @@ __global__ void random_sampleKernel(Tidx *result, ...@@ -302,14 +303,15 @@ __global__ void random_sampleKernel(Tidx *result,
for (int r = 0; r < sm_repeat; r++) { for (int r = 0; r < sm_repeat; r++) {
if (cid == 0) { if (cid == 0) {
GM2SM_ASYNC(probs_ + r * all_sm_size + cluster_id() * sm_size, x_sm, sm_size * sizeof(Tval)); GM2SM(probs_ + r * all_sm_size + cluster_id() * sm_size, x_sm, sm_size * sizeof(Tval));
} }
sync_cluster(); sync_cluster();
for (int index = cid; index < sm_size; index += BLOCK_SIZE) { for (int index = cid; index < sm_size; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) { if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature)); y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) { }
else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature)); y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature));
} }
else if constexpr (std::is_same_v<Tval, float>) { else if constexpr (std::is_same_v<Tval, float>) {
...@@ -332,13 +334,14 @@ __global__ void random_sampleKernel(Tidx *result, ...@@ -332,13 +334,14 @@ __global__ void random_sampleKernel(Tidx *result,
if (sm_step) { if (sm_step) {
if (cid == 0) { if (cid == 0) {
GM2SM_ASYNC(probs_ + sm_repeat * all_sm_size + sm_ind_start, x_sm, sm_step * sizeof(Tval)); GM2SM(probs_ + sm_repeat * all_sm_size + sm_ind_start, x_sm, sm_step * sizeof(Tval));
} }
sync_cluster(); sync_cluster();
for (int index = cid; index < sm_step; index += BLOCK_SIZE) { for (int index = cid; index < sm_step; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) { if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature)); y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) { }
else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature)); y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature));
} }
else if constexpr (std::is_same_v<Tval, float>) { else if constexpr (std::is_same_v<Tval, float>) {
...@@ -358,17 +361,18 @@ __global__ void random_sampleKernel(Tidx *result, ...@@ -358,17 +361,18 @@ __global__ void random_sampleKernel(Tidx *result,
__global_ptr__ Tcompute *sum_global_ = sum_global; __global_ptr__ Tcompute *sum_global_ = sum_global;
if (core_id() == 0) { if (core_id() == 0) {
SM2GM_ASYNC(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute)); SM2GM(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute));
} }
sync_cluster(); sync_cluster();
__shared__ Tcompute all_sum; __shared__ Tcompute all_sum;
__shared__ Tcompute z_sm[CLUSTER_SIZE];
if(cid == 0){ if(cid == 0){
GM2SM_ASYNC(sum_global_, x_sm, cluster_num() * sizeof(Tcompute)); GM2SM(sum_global_, z_sm, cluster_num() * sizeof(Tcompute));
} }
sync_cluster(); sync_cluster();
Tcompute all_sum_0 = sum<BLOCK_SIZE, Tcompute, Tcompute>(x_sm, cluster_num()); Tcompute all_sum_0 = sum<BLOCK_SIZE, Tcompute, Tcompute>(z_sm, cluster_num());
if (cid == 0) { if (cid == 0) {
all_sum = all_sum_0; all_sum = all_sum_0;
} }
...@@ -377,19 +381,19 @@ __global__ void random_sampleKernel(Tidx *result, ...@@ -377,19 +381,19 @@ __global__ void random_sampleKernel(Tidx *result,
if (thread_id == 0) { if (thread_id == 0) {
int end = topk; int end = topk;
float cumsum = 0.0f; float cumsum = 0.0f;
for(int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++){ for(int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++){
int read_len = (r < topk / buf_size ? buf_size : topk % buf_size); int read_len = (r < topk / buf_size ? buf_size : topk % buf_size);
GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval)); GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
for (int index = 0; index < read_len; index++) { for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) { if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(loadsm(&all_sum)); cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(loadsm(&all_sum));
}
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) { else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature) / to<float>(loadsm(&all_sum)); cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature) / to<float>(loadsm(&all_sum));
} }
else if constexpr (std::is_same_v<Tval, half>) { else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature) / to<float>(loadsm(&all_sum)); cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature) / to<float>(loadsm(&all_sum));
} }
if (cumsum >= topp) { if (cumsum >= topp) {
end = r * buf_size + index + 1; end = r * buf_size + index + 1;
...@@ -405,11 +409,12 @@ __global__ void random_sampleKernel(Tidx *result, ...@@ -405,11 +409,12 @@ __global__ void random_sampleKernel(Tidx *result,
for (int index = 0; index < read_len; index++) { for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) { if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature)/ to<float>(loadsm(&all_sum)); cumsum += exp((values_local[index] - max_value) / temperature)/ to<float>(loadsm(&all_sum));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) { }
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature) / to<float>(loadsm(&all_sum)); else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature) / to<float>(loadsm(&all_sum));
} }
else if constexpr (std::is_same_v<Tval, half>) { else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature)/ to<float>(loadsm(&all_sum)); cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature) / to<float>(loadsm(&all_sum));
} }
if (random_val < cumsum) { if (random_val < cumsum) {
result[0] = indices_global[r * buf_size + index]; result[0] = indices_global[r * buf_size + index];
...@@ -505,12 +510,13 @@ void random_sampleFunction(void *workspace, ...@@ -505,12 +510,13 @@ void random_sampleFunction(void *workspace,
Tval *values = (Tval *)workspace_value; Tval *values = (Tval *)workspace_value;
xpu_memcpy(values, (Tval *)probs, n * sizeof(Tval), XPU_DEVICE_TO_DEVICE); xpu_memcpy(values, (Tval *)probs, n * sizeof(Tval), XPU_DEVICE_TO_DEVICE);
Tval *values_global = values + n; Tval *values_global = values + n;
Tval *sum_global = values_global + cluster_num * core_num * topk_; char *workspace_sum = workspace_value + (n + cluster_num * core_num * topk_) * sizeof(Tval);
char *workspace_index = workspace_value + (n + cluster_num * core_num * topk_ + cluster_num) * sizeof(Tval); float *sum_global = (float *)workspace_sum;
char *workspace_index = workspace_sum + cluster_num * sizeof(float);
Tidx *indices = (Tidx *)workspace_index; Tidx *indices = (Tidx *)workspace_index;
Tidx *indices_global = indices + n; Tidx *indices_global = indices + n;
if (dosample){ if (dosample){
random_sampleKernel<core_num, Tval, Tval, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result, random_sampleKernel<cluster_num, core_num, Tval, float, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result,
(Tval *)probs, (Tval *)probs,
random_val, random_val,
topp, topp,
...@@ -560,10 +566,12 @@ infiniStatus_t Descriptor::create( ...@@ -560,10 +566,12 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result); CHECK_RESULT(result);
auto info = result.take(); auto info = result.take();
// size_t workspace_size = 3 * probs_desc->numel() * infiniSizeOf(probs_desc->dtype()) + probs_desc->numel() * infiniSizeOf(infiniDtype_t::INFINI_DTYPE_I32);
int cluster_num = 256; int cluster_num = 8;
int core_num = 64; int core_num = 64;
size_t workspace_size = (probs_desc->numel() + cluster_num * core_num * probs_desc->numel() + cluster_num) * infiniSizeOf(probs_desc->dtype()) + (probs_desc->numel() + cluster_num * core_num * probs_desc->numel()) * infiniSizeOf(result_desc->dtype()); int n = probs_desc->numel();
int topk = 50;//必须想办法控制workspace大小,如果topk太大会导致无法申请进而结果报错
size_t workspace_size = (n + cluster_num * core_num * topk) * (infiniSizeOf(probs_desc->dtype()) + infiniSizeOf(result_desc->dtype())) + cluster_num * sizeof(float);
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
info, info,
workspace_size, workspace_size,
......
...@@ -54,8 +54,7 @@ NUM_ITERATIONS = 1000 ...@@ -54,8 +54,7 @@ NUM_ITERATIONS = 1000
def random_sample(data, random_val, topp, topk, voc, temperature): def random_sample(data, random_val, topp, topk, voc, temperature):
if topp > 0 and topk > 1: if topp > 0 and topk > 1:
sorted_vals, sorted_indices = torch.sort(data, descending=True) sorted_vals, sorted_indices = torch.sort(data, descending=True)
print(sorted_vals[:topk])
print(sorted_indices[:topk])
scaled_vals = (sorted_vals - sorted_vals[0]) / temperature scaled_vals = (sorted_vals - sorted_vals[0]) / temperature
try: try:
probs = torch.softmax(scaled_vals, dim=0) probs = torch.softmax(scaled_vals, dim=0)
...@@ -158,7 +157,7 @@ def test( ...@@ -158,7 +157,7 @@ def test(
if sync is not None: if sync is not None:
sync() sync()
print(indices.actual_tensor(), ans)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
debug_all( debug_all(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment