Commit 31c962dc authored by zhe chen's avatar zhe chen
Browse files

Update huggingface model code

parent 773e90c1
...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel): ...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor): def forward(self, pixel_values):
return self.model.forward_features(tensor) return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel): class InternImageModelForImageClassification(PreTrainedModel):
...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel): ...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor, labels=None): def forward(self, pixel_values, labels=None):
outputs = self.model.forward(tensor) outputs = self.model.forward(pixel_values)
if labels is not None: if labels is not None:
logits = outputs['logits'] logits = outputs['logits']
......
...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel): ...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor): def forward(self, pixel_values):
return self.model.forward_features(tensor) return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel): class InternImageModelForImageClassification(PreTrainedModel):
...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel): ...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor, labels=None): def forward(self, pixel_values, labels=None):
outputs = self.model.forward(tensor) outputs = self.model.forward(pixel_values)
if labels is not None: if labels is not None:
logits = outputs['logits'] logits = outputs['logits']
......
...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel): ...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor): def forward(self, pixel_values):
return self.model.forward_features(tensor) return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel): class InternImageModelForImageClassification(PreTrainedModel):
...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel): ...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor, labels=None): def forward(self, pixel_values, labels=None):
outputs = self.model.forward(tensor) outputs = self.model.forward(pixel_values)
if labels is not None: if labels is not None:
logits = outputs['logits'] logits = outputs['logits']
......
...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel): ...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor): def forward(self, pixel_values):
return self.model.forward_features(tensor) return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel): class InternImageModelForImageClassification(PreTrainedModel):
...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel): ...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor, labels=None): def forward(self, pixel_values, labels=None):
outputs = self.model.forward(tensor) outputs = self.model.forward(pixel_values)
if labels is not None: if labels is not None:
logits = outputs['logits'] logits = outputs['logits']
......
...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel): ...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor): def forward(self, pixel_values):
return self.model.forward_features(tensor) return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel): class InternImageModelForImageClassification(PreTrainedModel):
...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel): ...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor, labels=None): def forward(self, pixel_values, labels=None):
outputs = self.model.forward(tensor) outputs = self.model.forward(pixel_values)
if labels is not None: if labels is not None:
logits = outputs['logits'] logits = outputs['logits']
......
...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel): ...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor): def forward(self, pixel_values):
return self.model.forward_features(tensor) return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel): class InternImageModelForImageClassification(PreTrainedModel):
...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel): ...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor, labels=None): def forward(self, pixel_values, labels=None):
outputs = self.model.forward(tensor) outputs = self.model.forward(pixel_values)
if labels is not None: if labels is not None:
logits = outputs['logits'] logits = outputs['logits']
......
...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel): ...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor): def forward(self, pixel_values):
return self.model.forward_features(tensor) return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel): class InternImageModelForImageClassification(PreTrainedModel):
...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel): ...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor, labels=None): def forward(self, pixel_values, labels=None):
outputs = self.model.forward(tensor) outputs = self.model.forward(pixel_values)
if labels is not None: if labels is not None:
logits = outputs['logits'] logits = outputs['logits']
......
...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel): ...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor): def forward(self, pixel_values):
return self.model.forward_features(tensor) return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel): class InternImageModelForImageClassification(PreTrainedModel):
...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel): ...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor, labels=None): def forward(self, pixel_values, labels=None):
outputs = self.model.forward(tensor) outputs = self.model.forward(pixel_values)
if labels is not None: if labels is not None:
logits = outputs['logits'] logits = outputs['logits']
......
...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel): ...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor): def forward(self, pixel_values):
return self.model.forward_features(tensor) return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel): class InternImageModelForImageClassification(PreTrainedModel):
...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel): ...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor, labels=None): def forward(self, pixel_values, labels=None):
outputs = self.model.forward(tensor) outputs = self.model.forward(pixel_values)
if labels is not None: if labels is not None:
logits = outputs['logits'] logits = outputs['logits']
......
...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel): ...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor): def forward(self, pixel_values):
return self.model.forward_features(tensor) return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel): class InternImageModelForImageClassification(PreTrainedModel):
...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel): ...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor, labels=None): def forward(self, pixel_values, labels=None):
outputs = self.model.forward(tensor) outputs = self.model.forward(pixel_values)
if labels is not None: if labels is not None:
logits = outputs['logits'] logits = outputs['logits']
......
...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel): ...@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor): def forward(self, pixel_values):
return self.model.forward_features(tensor) return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel): class InternImageModelForImageClassification(PreTrainedModel):
...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel): ...@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G remove_center=config.remove_center, # for InternImage-H/G
) )
def forward(self, tensor, labels=None): def forward(self, pixel_values, labels=None):
outputs = self.model.forward(tensor) outputs = self.model.forward(pixel_values)
if labels is not None: if labels is not None:
logits = outputs['logits'] logits = outputs['logits']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment