Commit ed0886cb authored by Rebecca Chen's avatar Rebecca Chen Committed by Copybara-Service
Browse files

Silence some pytype errors.

PiperOrigin-RevId: 436161190
Change-Id: Ia10df42f5f3f638a7a5b7b804c593a1e28a4ef44
parent b85ffe10
......@@ -65,7 +65,7 @@ class Rigid3Array:
"""Return identity Rigid3Array of given shape."""
return cls(
rotation_matrix.Rot3Array.identity(shape, dtype=dtype),
vector.Vec3Array.zeros(shape, dtype=dtype))
vector.Vec3Array.zeros(shape, dtype=dtype)) # pytype: disable=wrong-arg-count # trace-all-classes
def scale_translation(self, factor: Float) -> Rigid3Array:
"""Scale translation in Rigid3Array by 'factor'."""
......@@ -80,7 +80,7 @@ class Rigid3Array:
def from_array(cls, array):
rot = rotation_matrix.Rot3Array.from_array(array[..., :3])
vec = vector.Vec3Array.from_array(array[..., -1])
return cls(rot, vec)
return cls(rot, vec) # pytype: disable=wrong-arg-count # trace-all-classes
@classmethod
def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array:
......@@ -94,7 +94,7 @@ class Rigid3Array:
)
translation = vector.Vec3Array(
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
return cls(rotation, translation)
return cls(rotation, translation) # pytype: disable=wrong-arg-count # trace-all-classes
def __getstate__(self):
return (VERSION, (self.rotation, self.translation))
......
......@@ -73,7 +73,7 @@ class Rot3Array:
"""Returns identity of given shape."""
ones = jnp.ones(shape, dtype=dtype)
zeros = jnp.zeros(shape, dtype=dtype)
return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)
return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) # pytype: disable=wrong-arg-count # trace-all-classes
@classmethod
def from_two_vectors(cls, e0: vector.Vec3Array,
......@@ -96,7 +96,7 @@ class Rot3Array:
e1 = (e1 - c * e0).normalized()
# Compute e2 as cross product of e0 and e1.
e2 = e0.cross(e1)
return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) # pytype: disable=wrong-arg-count # trace-all-classes
@classmethod
def from_array(cls, array: jnp.ndarray) -> Rot3Array:
......@@ -137,7 +137,7 @@ class Rot3Array:
zx = 2 * (x * z - w * y)
zy = 2 * (y * z + w * x)
zz = 1 - 2 * (jnp.square(x) + jnp.square(y))
return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) # pytype: disable=wrong-arg-count # trace-all-classes
@classmethod
def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array:
......
......@@ -104,7 +104,7 @@ class Vec3Array:
"""Return Vec3Array corresponding to zeros of given shape."""
return cls(
jnp.zeros(shape, dtype), jnp.zeros(shape, dtype),
jnp.zeros(shape, dtype))
jnp.zeros(shape, dtype)) # pytype: disable=wrong-arg-count # trace-all-classes
def to_array(self) -> jnp.ndarray:
return jnp.stack([self.x, self.y, self.z], axis=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment