Commit 08a29c28 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/66: modified random sample test function

parent 04aa18f6
...@@ -21,7 +21,6 @@ from libinfiniop import ( ...@@ -21,7 +21,6 @@ from libinfiniop import (
# Configuration (Internal Use Only) # Configuration (Internal Use Only)
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# x_shape, x_stride # x_shape, x_stride
((32, 512), None), ((32, 512), None),
......
...@@ -22,7 +22,6 @@ from libinfiniop import ( ...@@ -22,7 +22,6 @@ from libinfiniop import (
# Configuration (Internal Use Only) # Configuration (Internal Use Only)
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# voc, random_val, topp, topk, temperature # voc, random_val, topp, topk, temperature
(512, 0.8, 0.8, 3, 0.5), (512, 0.8, 0.8, 3, 0.5),
...@@ -59,6 +58,7 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor) ...@@ -59,6 +58,7 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
if topp > 0 and topk > 1:
indices = torch.zeros([topk], dtype=torch.int64) indices = torch.zeros([topk], dtype=torch.int64)
dataNp = data.clone().detach() dataNp = data.clone().detach()
sorted_indices = torch.arange(voc) sorted_indices = torch.arange(voc)
...@@ -102,9 +102,7 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): ...@@ -102,9 +102,7 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
sum_s += dataNp[i] sum_s += dataNp[i]
if random_val < sum_s: if random_val < sum_s:
return indices[i] return indices[i]
else:
def random_sample_0(data):
return torch.argmax(data) return torch.argmax(data)
...@@ -124,12 +122,10 @@ def test( ...@@ -124,12 +122,10 @@ def test(
data = torch.arange(voc).float() * 0.0001 data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc) _perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device) data = data[_perm].to(x_dtype).to(torch_device)
if topp > 0 and topk > 1:
ans = random_sample( ans = random_sample(
data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu" data, random_val, topp, topk, voc, temperature, torch_device
) ) # 这个函数在device速度可能会很慢,可以通过data.to("cpu")方式加快计算过程
else:
ans = random_sample_0(data)
indices = torch.zeros([1], dtype=torch.int64).to(torch_device) indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment