Commit 08a29c28 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/66: modified random sample test function

parent 04aa18f6
...@@ -21,7 +21,6 @@ from libinfiniop import ( ...@@ -21,7 +21,6 @@ from libinfiniop import (
# Configuration (Internal Use Only) # Configuration (Internal Use Only)
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# x_shape, x_stride # x_shape, x_stride
((32, 512), None), ((32, 512), None),
......
...@@ -22,7 +22,6 @@ from libinfiniop import ( ...@@ -22,7 +22,6 @@ from libinfiniop import (
# Configuration (Internal Use Only) # Configuration (Internal Use Only)
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# voc, random_val, topp, topk, temperature # voc, random_val, topp, topk, temperature
(512, 0.8, 0.8, 3, 0.5), (512, 0.8, 0.8, 3, 0.5),
...@@ -59,53 +58,52 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor) ...@@ -59,53 +58,52 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
indices = torch.zeros([topk], dtype=torch.int64) if topp > 0 and topk > 1:
dataNp = data.clone().detach() indices = torch.zeros([topk], dtype=torch.int64)
sorted_indices = torch.arange(voc) dataNp = data.clone().detach()
sorted_indices = torch.arange(voc)
for i in range(topk):
for j in range(i + 1, voc): for i in range(topk):
if dataNp[i] < dataNp[j]: for j in range(i + 1, voc):
tmp = dataNp[i].clone().detach() if dataNp[i] < dataNp[j]:
dataNp[i] = dataNp[j].clone().detach() tmp = dataNp[i].clone().detach()
dataNp[j] = tmp dataNp[i] = dataNp[j].clone().detach()
dataNp[j] = tmp
tmpInd = sorted_indices[i].clone().detach()
sorted_indices[i] = sorted_indices[j].clone().detach() tmpInd = sorted_indices[i].clone().detach()
sorted_indices[j] = tmpInd sorted_indices[i] = sorted_indices[j].clone().detach()
sorted_indices[j] = tmpInd
# sorted_indices = torch.argsort(dataNp, descending=True)
indices = sorted_indices[:topk] # sorted_indices = torch.argsort(dataNp, descending=True)
indices = sorted_indices[:topk]
dataNp = dataNp[sorted_indices]
dataNp = dataNp[sorted_indices]
globalM = dataNp[0]
dataNp = (dataNp - globalM) / temperature globalM = dataNp[0]
dataNp = torch.softmax(dataNp.float(), dim=0) dataNp = (dataNp - globalM) / temperature
sum_s = 0 dataNp = torch.softmax(dataNp.float(), dim=0)
for end in range(topk): sum_s = 0
sum_s += dataNp[end] for end in range(topk):
if sum_s >= topp: sum_s += dataNp[end]
break if sum_s >= topp:
if end < topk - 1: break
end += 1 if end < topk - 1:
end += 1
else:
end = topk
sum_s = 0
for i in range(end):
sum_s += dataNp[i]
random_val *= sum_s
sum_s = 0
for i in range(end):
sum_s += dataNp[i]
if random_val < sum_s:
return indices[i]
else: else:
end = topk return torch.argmax(data)
sum_s = 0
for i in range(end):
sum_s += dataNp[i]
random_val *= sum_s
sum_s = 0
for i in range(end):
sum_s += dataNp[i]
if random_val < sum_s:
return indices[i]
def random_sample_0(data):
return torch.argmax(data)
def test( def test(
...@@ -124,12 +122,10 @@ def test( ...@@ -124,12 +122,10 @@ def test(
data = torch.arange(voc).float() * 0.0001 data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc) _perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device) data = data[_perm].to(x_dtype).to(torch_device)
if topp > 0 and topk > 1:
ans = random_sample( ans = random_sample(
data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu" data, random_val, topp, topk, voc, temperature, torch_device
) ) # 这个函数在device速度可能会很慢,可以通过data.to("cpu")方式加快计算过程
else:
ans = random_sample_0(data)
indices = torch.zeros([1], dtype=torch.int64).to(torch_device) indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment