Commit d741ee7d authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/342: success random_sample all

parent c5bc6628
......@@ -162,7 +162,7 @@ __device__ void findTopOne_local(
result[0] = indices_a;
}
template <unsigned int BLOCK_SIZE, typename Tval, typename Tcompute, typename Tidx>
template <unsigned int CLUSTER_SIZE, unsigned int BLOCK_SIZE, typename Tval, typename Tcompute, typename Tidx>
__global__ void random_sampleKernel(Tidx *result,
const Tval *probs,
float random_val,
......@@ -263,7 +263,7 @@ __global__ void random_sampleKernel(Tidx *result,
findTopk(values_global, indices_global, nthreads * topk, topk);
}
}
sync_cluster();
        // The section above computes the top-k; the results are stored in values_global and indices_global respectively.
__global_ptr__ Tval *values_global_ = values_global;
__shared__ Tval max_value;
......@@ -290,7 +290,8 @@ __global__ void random_sampleKernel(Tidx *result,
if(cid == 0){
if constexpr (std::is_same_v<Tcompute, half>) {
sum_ = __float2half(0.0f);
} else if constexpr (std::is_same_v<Tcompute, bfloat16_t>) {
}
else if constexpr (std::is_same_v<Tcompute, bfloat16_t>) {
sum_ = __float2bfloat16(0.0f);
}
else if constexpr (std::is_same_v<Tcompute, float>) {
......@@ -302,14 +303,15 @@ __global__ void random_sampleKernel(Tidx *result,
for (int r = 0; r < sm_repeat; r++) {
if (cid == 0) {
GM2SM_ASYNC(probs_ + r * all_sm_size + cluster_id() * sm_size, x_sm, sm_size * sizeof(Tval));
GM2SM(probs_ + r * all_sm_size + cluster_id() * sm_size, x_sm, sm_size * sizeof(Tval));
}
sync_cluster();
for (int index = cid; index < sm_size; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
}
else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature));
}
else if constexpr (std::is_same_v<Tval, float>) {
......@@ -332,13 +334,14 @@ __global__ void random_sampleKernel(Tidx *result,
if (sm_step) {
if (cid == 0) {
GM2SM_ASYNC(probs_ + sm_repeat * all_sm_size + sm_ind_start, x_sm, sm_step * sizeof(Tval));
GM2SM(probs_ + sm_repeat * all_sm_size + sm_ind_start, x_sm, sm_step * sizeof(Tval));
}
sync_cluster();
for (int index = cid; index < sm_step; index += BLOCK_SIZE) {
if constexpr (std::is_same_v<Tval, half>) {
y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
}
else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature));
}
else if constexpr (std::is_same_v<Tval, float>) {
......@@ -358,17 +361,18 @@ __global__ void random_sampleKernel(Tidx *result,
__global_ptr__ Tcompute *sum_global_ = sum_global;
if (core_id() == 0) {
SM2GM_ASYNC(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute));
SM2GM(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute));
}
sync_cluster();
__shared__ Tcompute all_sum;
__shared__ Tcompute z_sm[CLUSTER_SIZE];
if(cid == 0){
GM2SM_ASYNC(sum_global_, x_sm, cluster_num() * sizeof(Tcompute));
GM2SM(sum_global_, z_sm, cluster_num() * sizeof(Tcompute));
}
sync_cluster();
Tcompute all_sum_0 = sum<BLOCK_SIZE, Tcompute, Tcompute>(x_sm, cluster_num());
Tcompute all_sum_0 = sum<BLOCK_SIZE, Tcompute, Tcompute>(z_sm, cluster_num());
if (cid == 0) {
all_sum = all_sum_0;
}
......@@ -377,19 +381,19 @@ __global__ void random_sampleKernel(Tidx *result,
if (thread_id == 0) {
int end = topk;
float cumsum = 0.0f;
for(int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++){
int read_len = (r < topk / buf_size ? buf_size : topk % buf_size);
GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(loadsm(&all_sum));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature) / to<float>(loadsm(&all_sum));
cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(loadsm(&all_sum));
}
else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature) / to<float>(loadsm(&all_sum));
}
else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature) / to<float>(loadsm(&all_sum));
cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature) / to<float>(loadsm(&all_sum));
}
if (cumsum >= topp) {
end = r * buf_size + index + 1;
......@@ -405,11 +409,12 @@ __global__ void random_sampleKernel(Tidx *result,
for (int index = 0; index < read_len; index++) {
if constexpr (std::is_same_v<Tval, float>) {
cumsum += exp((values_local[index] - max_value) / temperature)/ to<float>(loadsm(&all_sum));
} else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature) / to<float>(loadsm(&all_sum));
}
else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature) / to<float>(loadsm(&all_sum));
}
else if constexpr (std::is_same_v<Tval, half>) {
cumsum += exp((to<float>(values_local[index]) - to<float>(loadsm(&max_value))) / temperature)/ to<float>(loadsm(&all_sum));
cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value))/ temperature) / to<float>(loadsm(&all_sum));
}
if (random_val < cumsum) {
result[0] = indices_global[r * buf_size + index];
......@@ -505,12 +510,13 @@ void random_sampleFunction(void *workspace,
Tval *values = (Tval *)workspace_value;
xpu_memcpy(values, (Tval *)probs, n * sizeof(Tval), XPU_DEVICE_TO_DEVICE);
Tval *values_global = values + n;
Tval *sum_global = values_global + cluster_num * core_num * topk_;
char *workspace_index = workspace_value + (n + cluster_num * core_num * topk_ + cluster_num) * sizeof(Tval);
char *workspace_sum = workspace_value + (n + cluster_num * core_num * topk_) * sizeof(Tval);
float *sum_global = (float *)workspace_sum;
char *workspace_index = workspace_sum + cluster_num * sizeof(float);
Tidx *indices = (Tidx *)workspace_index;
Tidx *indices_global = indices + n;
if (dosample){
random_sampleKernel<core_num, Tval, Tval, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result,
random_sampleKernel<cluster_num, core_num, Tval, float, Tidx><<<cluster_num, core_num, stream>>>((Tidx *)result,
(Tval *)probs,
random_val,
topp,
......@@ -560,10 +566,12 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result);
auto info = result.take();
// size_t workspace_size = 3 * probs_desc->numel() * infiniSizeOf(probs_desc->dtype()) + probs_desc->numel() * infiniSizeOf(infiniDtype_t::INFINI_DTYPE_I32);
int cluster_num = 256;
int cluster_num = 8;
int core_num = 64;
size_t workspace_size = (probs_desc->numel() + cluster_num * core_num * probs_desc->numel() + cluster_num) * infiniSizeOf(probs_desc->dtype()) + (probs_desc->numel() + cluster_num * core_num * probs_desc->numel()) * infiniSizeOf(result_desc->dtype());
int n = probs_desc->numel();
int topk = 50;//必须想办法控制workspace大小,如果topk太大会导致无法申请进而结果报错
size_t workspace_size = (n + cluster_num * core_num * topk) * (infiniSizeOf(probs_desc->dtype()) + infiniSizeOf(result_desc->dtype())) + cluster_num * sizeof(float);
*desc_ptr = new Descriptor(
info,
workspace_size,
......
......@@ -54,8 +54,7 @@ NUM_ITERATIONS = 1000
def random_sample(data, random_val, topp, topk, voc, temperature):
if topp > 0 and topk > 1:
sorted_vals, sorted_indices = torch.sort(data, descending=True)
print(sorted_vals[:topk])
print(sorted_indices[:topk])
scaled_vals = (sorted_vals - sorted_vals[0]) / temperature
try:
probs = torch.softmax(scaled_vals, dim=0)
......@@ -158,7 +157,7 @@ def test(
if sync is not None:
sync()
print(indices.actual_tensor(), ans)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug_all(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment