Commit ed0886cb authored by Rebecca Chen's avatar Rebecca Chen Committed by Copybara-Service
Browse files

Silence some pytype errors.

PiperOrigin-RevId: 436161190
Change-Id: Ia10df42f5f3f638a7a5b7b804c593a1e28a4ef44
parent b85ffe10
...@@ -65,7 +65,7 @@ class Rigid3Array: ...@@ -65,7 +65,7 @@ class Rigid3Array:
"""Return identity Rigid3Array of given shape.""" """Return identity Rigid3Array of given shape."""
return cls( return cls(
rotation_matrix.Rot3Array.identity(shape, dtype=dtype), rotation_matrix.Rot3Array.identity(shape, dtype=dtype),
vector.Vec3Array.zeros(shape, dtype=dtype)) vector.Vec3Array.zeros(shape, dtype=dtype)) # pytype: disable=wrong-arg-count # trace-all-classes
def scale_translation(self, factor: Float) -> Rigid3Array: def scale_translation(self, factor: Float) -> Rigid3Array:
"""Scale translation in Rigid3Array by 'factor'.""" """Scale translation in Rigid3Array by 'factor'."""
...@@ -80,7 +80,7 @@ class Rigid3Array: ...@@ -80,7 +80,7 @@ class Rigid3Array:
def from_array(cls, array): def from_array(cls, array):
rot = rotation_matrix.Rot3Array.from_array(array[..., :3]) rot = rotation_matrix.Rot3Array.from_array(array[..., :3])
vec = vector.Vec3Array.from_array(array[..., -1]) vec = vector.Vec3Array.from_array(array[..., -1])
return cls(rot, vec) return cls(rot, vec) # pytype: disable=wrong-arg-count # trace-all-classes
@classmethod @classmethod
def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array: def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array:
...@@ -94,7 +94,7 @@ class Rigid3Array: ...@@ -94,7 +94,7 @@ class Rigid3Array:
) )
translation = vector.Vec3Array( translation = vector.Vec3Array(
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]) array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
return cls(rotation, translation) return cls(rotation, translation) # pytype: disable=wrong-arg-count # trace-all-classes
def __getstate__(self): def __getstate__(self):
return (VERSION, (self.rotation, self.translation)) return (VERSION, (self.rotation, self.translation))
......
...@@ -73,7 +73,7 @@ class Rot3Array: ...@@ -73,7 +73,7 @@ class Rot3Array:
"""Returns identity of given shape.""" """Returns identity of given shape."""
ones = jnp.ones(shape, dtype=dtype) ones = jnp.ones(shape, dtype=dtype)
zeros = jnp.zeros(shape, dtype=dtype) zeros = jnp.zeros(shape, dtype=dtype)
return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) # pytype: disable=wrong-arg-count # trace-all-classes
@classmethod @classmethod
def from_two_vectors(cls, e0: vector.Vec3Array, def from_two_vectors(cls, e0: vector.Vec3Array,
...@@ -96,7 +96,7 @@ class Rot3Array: ...@@ -96,7 +96,7 @@ class Rot3Array:
e1 = (e1 - c * e0).normalized() e1 = (e1 - c * e0).normalized()
# Compute e2 as cross product of e0 and e1. # Compute e2 as cross product of e0 and e1.
e2 = e0.cross(e1) e2 = e0.cross(e1)
return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) # pytype: disable=wrong-arg-count # trace-all-classes
@classmethod @classmethod
def from_array(cls, array: jnp.ndarray) -> Rot3Array: def from_array(cls, array: jnp.ndarray) -> Rot3Array:
...@@ -137,7 +137,7 @@ class Rot3Array: ...@@ -137,7 +137,7 @@ class Rot3Array:
zx = 2 * (x * z - w * y) zx = 2 * (x * z - w * y)
zy = 2 * (y * z + w * x) zy = 2 * (y * z + w * x)
zz = 1 - 2 * (jnp.square(x) + jnp.square(y)) zz = 1 - 2 * (jnp.square(x) + jnp.square(y))
return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) # pytype: disable=wrong-arg-count # trace-all-classes
@classmethod @classmethod
def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array: def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array:
......
...@@ -104,7 +104,7 @@ class Vec3Array: ...@@ -104,7 +104,7 @@ class Vec3Array:
"""Return Vec3Array corresponding to zeros of given shape.""" """Return Vec3Array corresponding to zeros of given shape."""
return cls( return cls(
jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), jnp.zeros(shape, dtype),
jnp.zeros(shape, dtype)) jnp.zeros(shape, dtype)) # pytype: disable=wrong-arg-count # trace-all-classes
def to_array(self) -> jnp.ndarray: def to_array(self) -> jnp.ndarray:
return jnp.stack([self.x, self.y, self.z], axis=-1) return jnp.stack([self.x, self.y, self.z], axis=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment