"vscode:/vscode.git/clone" did not exist on "ba552dd0274206628d59ea349171a0d4d8632c3a"
Unverified commit 61f79b29, authored by Stas Bekman and committed by GitHub
Browse files

[gptj] support older pytorch version (#22325)



* [gptj] support older pytorch version

* contributor

* contributor

* make copies

---------
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
parent 80e3b363
...@@ -55,7 +55,7 @@ CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -55,7 +55,7 @@ CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    """Build the sinusoidal rotary position table of shape ``(num_pos, dim)``.

    The first ``dim // 2`` columns hold ``sin`` values and the last
    ``dim // 2`` columns hold ``cos`` values, one row per position.

    Args:
        num_pos: Number of positions (rows) to generate.
        dim: Embedding dimension; only even indices ``0, 2, ...`` feed the
            inverse-frequency table, so the output has ``dim`` columns when
            ``dim`` is even.

    Returns:
        A float tensor of shape ``(num_pos, dim)``.
    """
    # Standard transformer inverse frequencies: 10000^(-2i/dim).
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
    # Outer product positions x frequencies -> (num_pos, dim // 2).
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
    # Use torch.cat rather than torch.concat: `concat` is only an alias added
    # in newer PyTorch releases, while `cat` works on older versions too.
    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
# Copied from transformers.models.gptj.modeling_gptj.rotate_every_two # Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
......
...@@ -18,6 +18,7 @@ import warnings ...@@ -18,6 +18,7 @@ import warnings
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
import torch.fx
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
...@@ -57,7 +58,7 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -57,7 +58,7 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    """Build the sinusoidal rotary position table of shape ``(num_pos, dim)``.

    The first ``dim // 2`` columns hold ``sin`` values and the last
    ``dim // 2`` columns hold ``cos`` values, one row per position.

    Args:
        num_pos: Number of positions (rows) to generate.
        dim: Embedding dimension; only even indices ``0, 2, ...`` feed the
            inverse-frequency table, so the output has ``dim`` columns when
            ``dim`` is even.

    Returns:
        A float tensor of shape ``(num_pos, dim)``.
    """
    # Standard transformer inverse frequencies: 10000^(-2i/dim).
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
    # Outer product positions x frequencies -> (num_pos, dim // 2).
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
    # Use torch.cat rather than torch.concat: `concat` is only an alias added
    # in newer PyTorch releases, while `cat` works on older versions too.
    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
@torch.fx.wrap @torch.fx.wrap
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment