Commit 4f6d038c authored by OlivierDehaene's avatar OlivierDehaene
Browse files

fix(server): fix multinomial implem in Sampling

parent a6c18c39
...@@ -25,10 +25,10 @@ class Sampling: ...@@ -25,10 +25,10 @@ class Sampling:
def __call__(self, logits): def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, -1) probs = torch.nn.functional.softmax(logits, -1)
# Avoid GPU<->CPU sync done by torch multinomial
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
q = torch.empty_like(probs).exponential_(1, generator=self.generator).div_(probs) q = torch.empty_like(probs).exponential_(1, generator=self.generator)
return probs.div_(q).argmax()
return q.argmax()
class Greedy: class Greedy:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment