"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "df99f8c5a1c54d64fb013b43107011390c3be0d5"
Unverified Commit 4f0337a0 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[Time Series Transformer] Add doc tests (#19607)



* Add doc tests

* Make it more consistent
Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent c937f0b9
...@@ -1614,33 +1614,30 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel): ...@@ -1614,33 +1614,30 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
Examples: Examples:
```python ```python
>>> from transformers import TimeSeriesTransformerModel >>> from huggingface_hub import hf_hub_download
>>> import torch >>> import torch
>>> from transformers import TimeSeriesTransformerModel
>>> model = TimeSeriesTransformerModel.from_pretrained("huggingface/tst-base") >>> file = hf_hub_download(
... repo_id="kashif/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
>>> inputs = dict() ... )
>>> batch_size = 2 >>> batch = torch.load(file)
>>> cardinality = 5
>>> num_time_features = 10 >>> model = TimeSeriesTransformerModel.from_pretrained("huggingface/time-series-transformer-tourism-monthly")
>>> content_length = 8
>>> prediction_length = 2 >>> # during training, one provides both past and future values
>>> lags_sequence = [2, 3] >>> # as well as possible additional features
>>> past_length = context_length + max(lags_sequence) >>> outputs = model(
... past_values=batch["past_values"],
>>> # encoder inputs ... past_time_features=batch["past_time_features"],
>>> inputs["static_categorical_features"] = ids_tensor([batch_size, 1], cardinality) ... past_observed_mask=batch["past_observed_mask"],
>>> inputs["static_real_features"] = torch.randn([batch_size, 1]) ... static_categorical_features=batch["static_categorical_features"],
>>> inputs["past_time_features"] = torch.randn([batch_size, past_length, num_time_features]) ... static_real_features=batch["static_real_features"],
>>> inputs["past_values"] = torch.randn([batch_size, past_length]) ... future_values=batch["future_values"],
>>> inputs["past_observed_mask"] = torch.ones([batch_size, past_length]) ... future_time_features=batch["future_time_features"],
... )
>>> # decoder inputs
>>> inputs["future_time_features"] = torch.randn([batch_size, prediction_length, num_time_features]) >>> last_hidden_state = outputs.last_hidden_state
>>> inputs["future_values"] = torch.randn([batch_size, prediction_length])
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```""" ```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
...@@ -1789,33 +1786,47 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel): ...@@ -1789,33 +1786,47 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
Examples: Examples:
```python ```python
>>> from transformers import TimeSeriesTransformerForPrediction >>> from huggingface_hub import hf_hub_download
>>> import torch >>> import torch
>>> from transformers import TimeSeriesTransformerForPrediction
>>> file = hf_hub_download(
... repo_id="kashif/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
... )
>>> batch = torch.load(file)
>>> model = TimeSeriesTransformerForPrediction.from_pretrained(
... "huggingface/time-series-transformer-tourism-monthly"
... )
>>> # during training, one provides both past and future values
>>> # as well as possible additional features
>>> outputs = model(
... past_values=batch["past_values"],
... past_time_features=batch["past_time_features"],
... past_observed_mask=batch["past_observed_mask"],
... static_categorical_features=batch["static_categorical_features"],
... static_real_features=batch["static_real_features"],
... future_values=batch["future_values"],
... future_time_features=batch["future_time_features"],
... )
>>> model = TimeSeriesTransformerForPrediction.from_pretrained("huggingface/tst-base")
>>> inputs = dict()
>>> batch_size = 2
>>> cardinality = 5
>>> num_time_features = 10
>>> content_length = 8
>>> prediction_length = 2
>>> lags_sequence = [2, 3]
>>> past_length = context_length + max(lags_sequence)
>>> # encoder inputs
>>> inputs["static_categorical_features"] = ids_tensor([batch_size, 1], cardinality)
>>> inputs["static_real_features"] = torch.randn([batch_size, 1])
>>> inputs["past_time_features"] = torch.randn([batch_size, past_length, num_time_features])
>>> inputs["past_values"] = torch.randn([batch_size, past_length])
>>> inputs["past_observed_mask"] = torch.ones([batch_size, past_length])
>>> # decoder inputs
>>> inputs["future_time_features"] = torch.randn([batch_size, prediction_length, num_time_features])
>>> inputs["future_values"] = torch.randn([batch_size, prediction_length])
>>> outputs = model(**inputs)
>>> loss = outputs.loss >>> loss = outputs.loss
>>> loss.backward()
>>> # during inference, one only provides past values
>>> # as well as possible additional features
>>> # the model autoregressively generates future values
>>> outputs = model.generate(
... past_values=batch["past_values"],
... past_time_features=batch["past_time_features"],
... past_observed_mask=batch["past_observed_mask"],
... static_categorical_features=batch["static_categorical_features"],
... static_real_features=batch["static_real_features"],
... future_time_features=batch["future_time_features"],
... )
>>> mean_prediction = outputs.sequences.mean(dim=1)
```""" ```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
...@@ -104,6 +104,7 @@ src/transformers/models/segformer/modeling_tf_segformer.py ...@@ -104,6 +104,7 @@ src/transformers/models/segformer/modeling_tf_segformer.py
src/transformers/models/swin/configuration_swin.py src/transformers/models/swin/configuration_swin.py
src/transformers/models/swin/modeling_swin.py src/transformers/models/swin/modeling_swin.py
src/transformers/models/swinv2/configuration_swinv2.py src/transformers/models/swinv2/configuration_swinv2.py
src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py
src/transformers/models/trocr/modeling_trocr.py src/transformers/models/trocr/modeling_trocr.py
src/transformers/models/unispeech/configuration_unispeech.py src/transformers/models/unispeech/configuration_unispeech.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment