Unverified Commit 326c84c4 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Compile RoPE while preserving true on-policy behavior (#12161)

parent 8da608cc
...@@ -125,8 +125,13 @@ class RotaryEmbedding(CustomOp): ...@@ -125,8 +125,13 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache: torch.Tensor self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False) self.register_buffer("cos_sin_cache", cache, persistent=False)
self._apply_rotary_emb_wrapped = _apply_rotary_emb
if get_global_server_args().rl_on_policy_target == "fsdp": if get_global_server_args().rl_on_policy_target == "fsdp":
self._forward_method = self.forward_native self._forward_method = self.forward_native
self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)(
self._apply_rotary_emb_wrapped
)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency.""" """Compute the inverse frequency."""
...@@ -185,14 +190,16 @@ class RotaryEmbedding(CustomOp): ...@@ -185,14 +190,16 @@ class RotaryEmbedding(CustomOp):
query = query.view(num_tokens, -1, self.head_size) query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim] query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :] query_pass = query[..., self.rotary_dim :]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) query_rot = self._apply_rotary_emb_wrapped(
query_rot, cos, sin, self.is_neox_style
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size) key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim] key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :] key_pass = key[..., self.rotary_dim :]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) key_rot = self._apply_rotary_emb_wrapped(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment