Unverified Commit a6ed694a authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

renaming no_ to num_

parent 3f498d32
......@@ -296,8 +296,8 @@ def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(t: torch.Tensor, num_dims: int):
    """Collapse the last ``num_dims`` dimensions of ``t`` into one.

    Args:
        t: input tensor with at least ``num_dims`` dimensions.
        num_dims: number of trailing dimensions to flatten together.

    Returns:
        A tensor of shape ``t.shape[:-num_dims] + (K,)`` where ``K`` is the
        product of the flattened trailing dimensions.
    """
    # reshape (not view) so non-contiguous inputs are handled too
    return t.reshape(t.shape[:-num_dims] + (-1,))
def masked_mean(mask, value, dim, eps=1e-10):
......@@ -324,18 +324,18 @@ def one_hot(x, num_classes, dtype=torch.float32):
return x_one_hot
def batched_gather(data, inds, dim=0, num_batch_dims=0):
    """Gather slices of ``data`` along ``dim``, broadcasting over leading batch dims.

    Args:
        data: source tensor; its first ``num_batch_dims`` dimensions are
            treated as batch dimensions.
        inds: integer index tensor whose leading dimensions line up
            (broadcast) with the batch dimensions of ``data`` — assumes
            ``len(inds.shape) >= num_batch_dims``; TODO confirm with callers.
        dim: dimension of ``data`` to gather along; if non-negative it must
            not fall inside the batch dimensions.
        num_batch_dims: number of leading batch dimensions of ``data``.

    Returns:
        Tensor of gathered values: batch dims, then the trailing shape of
        ``inds``, then the remaining (non-gathered) dims of ``data``.
    """
    assert dim < 0 or dim - num_batch_dims >= 0

    # Build a broadcastable arange for each batch dimension so advanced
    # indexing pairs every batch element with its own row of indices.
    ranges = []
    for i, s in enumerate(data.shape[:num_batch_dims]):
        r = torch.arange(s)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    # Full slices for every non-batch dimension, except the gather dim,
    # which is replaced by the index tensor itself.
    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - num_batch_dims)
    ]
    remaining_dims[dim - num_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: indexing with a plain Python list mixing tensors
    # and slices is deprecated advanced-indexing behavior.
    return data[tuple(ranges)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment