Commit 08a29c28 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/66: modified random sample test function

parent 04aa18f6
......@@ -21,7 +21,6 @@ from libinfiniop import (
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# x_shape, x_stride
((32, 512), None),
......
......@@ -22,7 +22,6 @@ from libinfiniop import (
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# voc, random_val, topp, topk, temperature
(512, 0.8, 0.8, 3, 0.5),
......@@ -59,53 +58,52 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
indices = torch.zeros([topk], dtype=torch.int64)
dataNp = data.clone().detach()
sorted_indices = torch.arange(voc)
for i in range(topk):
for j in range(i + 1, voc):
if dataNp[i] < dataNp[j]:
tmp = dataNp[i].clone().detach()
dataNp[i] = dataNp[j].clone().detach()
dataNp[j] = tmp
tmpInd = sorted_indices[i].clone().detach()
sorted_indices[i] = sorted_indices[j].clone().detach()
sorted_indices[j] = tmpInd
# sorted_indices = torch.argsort(dataNp, descending=True)
indices = sorted_indices[:topk]
dataNp = dataNp[sorted_indices]
globalM = dataNp[0]
dataNp = (dataNp - globalM) / temperature
dataNp = torch.softmax(dataNp.float(), dim=0)
sum_s = 0
for end in range(topk):
sum_s += dataNp[end]
if sum_s >= topp:
break
if end < topk - 1:
end += 1
if topp > 0 and topk > 1:
indices = torch.zeros([topk], dtype=torch.int64)
dataNp = data.clone().detach()
sorted_indices = torch.arange(voc)
for i in range(topk):
for j in range(i + 1, voc):
if dataNp[i] < dataNp[j]:
tmp = dataNp[i].clone().detach()
dataNp[i] = dataNp[j].clone().detach()
dataNp[j] = tmp
tmpInd = sorted_indices[i].clone().detach()
sorted_indices[i] = sorted_indices[j].clone().detach()
sorted_indices[j] = tmpInd
# sorted_indices = torch.argsort(dataNp, descending=True)
indices = sorted_indices[:topk]
dataNp = dataNp[sorted_indices]
globalM = dataNp[0]
dataNp = (dataNp - globalM) / temperature
dataNp = torch.softmax(dataNp.float(), dim=0)
sum_s = 0
for end in range(topk):
sum_s += dataNp[end]
if sum_s >= topp:
break
if end < topk - 1:
end += 1
else:
end = topk
sum_s = 0
for i in range(end):
sum_s += dataNp[i]
random_val *= sum_s
sum_s = 0
for i in range(end):
sum_s += dataNp[i]
if random_val < sum_s:
return indices[i]
else:
end = topk
sum_s = 0
for i in range(end):
sum_s += dataNp[i]
random_val *= sum_s
sum_s = 0
for i in range(end):
sum_s += dataNp[i]
if random_val < sum_s:
return indices[i]
def random_sample_0(data):
return torch.argmax(data)
return torch.argmax(data)
def test(
......@@ -124,12 +122,10 @@ def test(
data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device)
if topp > 0 and topk > 1:
ans = random_sample(
data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
)
else:
ans = random_sample_0(data)
ans = random_sample(
data, random_val, topp, topk, voc, temperature, torch_device
) # 这个函数在device速度可能会很慢,可以通过data.to("cpu")方式加快计算过程
indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment