Unverified Commit a6ed694a authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

renaming no_ to num_

parent 3f498d32
...@@ -296,8 +296,8 @@ def permute_final_dims(tensor: torch.Tensor, inds: List[int]): ...@@ -296,8 +296,8 @@ def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
return tensor.permute(first_inds + [zero_index + i for i in inds]) return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(t: torch.Tensor, num_dims: int):
    """Collapse the last ``num_dims`` dimensions of ``t`` into a single one.

    E.g. a tensor of shape (a, b, c, d) with ``num_dims=2`` becomes (a, b, c*d).
    """
    # Keep the leading dims, let reshape infer the merged trailing size.
    collapsed_shape = t.shape[:-num_dims] + (-1,)
    return t.reshape(collapsed_shape)
def masked_mean(mask, value, dim, eps=1e-10): def masked_mean(mask, value, dim, eps=1e-10):
...@@ -324,18 +324,18 @@ def one_hot(x, num_classes, dtype=torch.float32): ...@@ -324,18 +324,18 @@ def one_hot(x, num_classes, dtype=torch.float32):
return x_one_hot return x_one_hot
def batched_gather(data, inds, dim=0, num_batch_dims=0):
    """Gather slices of ``data`` along ``dim`` using per-batch indices.

    The first ``num_batch_dims`` dimensions of ``data`` are treated as batch
    dimensions; ``inds`` selects entries along dimension ``dim`` independently
    for each batch element (advanced-indexing semantics).

    Args:
        data: source tensor.
        inds: index tensor; its leading dims must broadcast against the batch
            dims of ``data`` (assumed by the caller — not checked here).
        dim: dimension of ``data`` to gather from (counted including batch dims
            when non-negative).
        num_batch_dims: number of leading batch dimensions of ``data``.

    Returns:
        Tensor of gathered values.
    """
    # A non-negative dim must lie at or after the batch dims.
    assert dim < 0 or dim - num_batch_dims >= 0
    ranges = []
    # Build one broadcastable arange per batch dim so each batch element
    # indexes with its own row of `inds`.
    for i, s in enumerate(data.shape[:num_batch_dims]):
        r = torch.arange(s)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)
    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - num_batch_dims)
    ]
    remaining_dims[dim - num_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: list-based advanced indexing with mixed
    # tensors/slices is deprecated and rejected by newer PyTorch/NumPy.
    return data[tuple(ranges)]
...@@ -398,4 +398,4 @@ def set_jit_fusion_options(): ...@@ -398,4 +398,4 @@ def set_jit_fusion_options():
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False) torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True) torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True) torch._C._jit_override_can_fuse_on_gpu(True)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment