prithvi_geospatial_mae.py 14.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This is a demo script showing how to use the
PrithviGeospatialMAE model with vLLM
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa

Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa

The requirements for running this script are:
- Install [terratorch, albumentations, rasterio] in your Python environment
- Download the model weights into a 'model' folder local to the script
  (temporary measure until the proper config.json file is uploaded to HF)
- Download an input example image (India_900498_S2Hand.tif) and place it in
  the same folder as the script (or specify it with the --data_file argument)

Run the example:
python prithvi_geospatial_mae.py

"""  # noqa: E501

import argparse
import datetime
import os
from typing import Union

import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule

from vllm import LLM

# Marker for missing pixels in the input rasters: load_example() leaves pixels
# equal to this value out of the mean/std normalization.
NO_DATA = -9999
# Value load_example() writes in place of NO_DATA pixels when normalizing;
# process_channel_group() treats pixels equal to it as invalid.
NO_DATA_FLOAT = 0.0001
# Lower bound used by the contrast rescaling in process_channel_group().
OFFSET = 0
# Percentile of the valid pixels used (with a floor of 3000) as the upper
# bound of the contrast rescaling in process_channel_group().
PERCENTILE = 99

# JSON text that this script writes verbatim to "model/config.json" when the
# module is loaded. It declares the "PrithviGeoSpatialMAE" architecture and
# carries the task/model/optimizer/scheduler settings under "pretrained_cfg".
model_config = """{
  "architectures": ["PrithviGeoSpatialMAE"],
  "num_classes": 0,
  "pretrained_cfg": {
    "task_args": {
      "task": "SemanticSegmentationTask",
      "model_factory": "EncoderDecoderFactory",
      "loss": "ce",
      "ignore_index": -1,
      "lr": 0.001,
      "freeze_backbone": false,
      "freeze_decoder": false,
      "plot_on_val": 10,
      "optimizer": "AdamW",
      "scheduler": "CosineAnnealingLR"
    },
    "model_args": {
      "backbone_pretrained": false,
      "backbone": "prithvi_eo_v2_300_tl",
      "decoder": "UperNetDecoder",
      "decoder_channels": 256,
      "decoder_scale_modules": true,
      "num_classes": 2,
      "rescale": true,
      "backbone_bands": [
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2"
      ],
      "head_dropout": 0.1,
      "necks": [
        {
          "name": "SelectIndices",
          "indices": [
            5,
            11,
            17,
            23
          ]
        },
        {
          "name": "ReshapeTokensToImage"
        }
      ]
    },
    "optimizer_params" : {
      "lr": 5.0e-05,
      "betas": [0.9, 0.999],
      "eps": [1.0e-08],
      "weight_decay": 0.05,
      "amsgrad": false,
      "maximize": false,
      "capturable": false,
      "differentiable": false
    },
    "scheduler_params" : {
        "T_max": 50,
        "eta_min": 0,
        "last_epoch": -1,
        "verbose": "deprecated"
    }
  },


  "torch_dtype": "float32"
}
"""

# Stop-gap: (re)generate "config.json" inside the local "model" folder every
# time this module is loaded, overwriting any previous copy.
# This can be dropped once the correct config.json is available on HF.
with open(
    os.path.join(os.path.dirname(__file__), "./model/config.json"),
    mode="w",
) as config_file:
    config_file.write(model_config)

# Settings for the Sen1Floods11 datamodule. generate_datamodule() forwards
# "data_root", "batch_size", "num_workers", "bands", "drop_last" and
# "test_transform"; main() additionally reads "bands" to locate the RGB
# channels. The remaining keys are not read anywhere in this script.
datamodule_config = {
    "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
    "batch_size": 16,
    "constant_scale": 0.0001,
    # NOTE(review): hard-coded absolute path that looks site-specific —
    # confirm whether it must exist for inference-only use.
    "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
    "drop_last": True,
    "no_data_replace": 0.0,
    "no_label_replace": -1,
    "num_workers": 8,
    # Transform pipeline handed to the datamodule: resize to 448x448, then
    # convert to a tensor.
    "test_transform": [
        albumentations.Resize(
            always_apply=False, height=448, interpolation=1, p=1, width=448
        ),
        albumentations.pytorch.ToTensorV2(
            transpose_mask=False, always_apply=True, p=1.0
        ),
    ],
}


class PrithviMAE:
    """Thin wrapper that runs the Prithvi model through vLLM's ``LLM`` API.

    The model is loaded from the "model" folder located next to this script.
    """

    def __init__(self):
        print("Initializing PrithviMAE model")
        self.model = LLM(
            model=os.path.join(os.path.dirname(__file__), "./model"),
            skip_tokenizer_init=True,
            dtype="float32",
        )

    def run(self, input_data, location_coords):
        """Run one inference pass and return the raw model output.

        Args:
            input_data: tensor passed to the model as "pixel_values"; an empty
                tensor is substituted when it is None.
            location_coords: tensor passed to the model as "location_coords";
                an empty tensor is substituted when it is None.

        Returns:
            The ``outputs.data`` attribute of the first result returned by
            ``self.model.encode``.
        """
        # Function-scope import so the timing below needs no file-level change.
        import time

        print("################ Running inference on vLLM ##############")
        # merge the inputs into one data structure
        mm_data = {
            "pixel_values": torch.empty(0) if input_data is None else input_data,
            "location_coords": torch.empty(0)
            if location_coords is None
            else location_coords,
        }

        # The single placeholder token id only carries the multi-modal payload.
        prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}

        # Time only the encode call so the message below reports a real value.
        start = time.perf_counter()
        outputs = self.model.encode(prompt, use_tqdm=False)
        elapsed = time.perf_counter() - start
        print(
            "################ Inference done "
            f"(it took {elapsed:.2f} seconds)  ##############"
        )

        return outputs[0].outputs.data


def generate_datamodule():
    """Instantiate the Sen1Floods11 datamodule from ``datamodule_config``.

    Only the settings listed below are forwarded to the constructor; any
    other key of ``datamodule_config`` is ignored here.

    Returns:
        The constructed Sen1Floods11NonGeoDataModule.
    """
    forwarded_keys = (
        "data_root",
        "batch_size",
        "num_workers",
        "bands",
        "drop_last",
        "test_transform",
    )
    settings = {key: datamodule_config[key] for key in forwarded_keys}
    return Sen1Floods11NonGeoDataModule(**settings)


def process_channel_group(orig_img, channels):
    """Pick a group of bands and stretch their values into the [0, 1] range.

    Args:
        orig_img: torch.Tensor holding the original (reference) image
                  with shape = (bands, H, W).
        channels: list of band indices to keep (e.g. the RGB channels).

    Returns:
        torch.Tensor with shape (num_channels, height, width); pixels equal
        to NO_DATA_FLOAT in the selected bands are set to zero.
    """

    selected = orig_img[channels, ...]
    # Everything that is not the no-data marker counts as a valid pixel.
    valid_mask = ~(selected == NO_DATA_FLOAT)

    # Contrast stretch: the upper bound is the PERCENTILE-th percentile of the
    # valid pixels, but never below 3000; the lower bound is OFFSET.
    upper = max(3000, np.percentile(selected[valid_mask], PERCENTILE))
    lower = OFFSET
    stretched = torch.clamp((selected - lower) / (upper - lower), 0, 1)

    # Invalid pixels are rendered as zeros
    stretched[~valid_mask] = 0

    return stretched


def read_geotiff(file_path: str):
    """Load every band of the raster at *file_path* along with its metadata.

    Args:
        file_path: path to image file.

    Returns:
        np.ndarray with shape (bands, height, width)
        meta info dict
        the value of the dataset's ``lnglat()``, or None when that call fails
    """

    with rasterio.open(file_path) as dataset:
        pixels = dataset.read()
        metadata = dataset.meta
        try:
            lng_lat = dataset.lnglat()
        except Exception:
            # Coordinates are optional: fall back to None when unavailable
            lng_lat = None

    return pixels, metadata, lng_lat


def save_geotiff(image, output_path: str, meta: dict):
    """Write a multi-band image to a GeoTIFF file, one band at a time.

    Args:
        image: np.ndarray with shape (bands, height, width)
        output_path: path where to save the image
        meta: dict with meta info, forwarded as keyword arguments to
              ``rasterio.open``.
    """

    band_count = image.shape[0]
    with rasterio.open(output_path, "w", **meta) as dest:
        # Band numbers passed to write() start at 1, array indices at 0.
        for band_number in range(1, band_count + 1):
            dest.write(image[band_number - 1, :, :], band_number)


def _convert_np_uint8(float_image: torch.Tensor):
    image = float_image.numpy() * 255.0
    image = image.astype(dtype=np.uint8)

    return image


def load_example(
    file_paths: list[str],
    mean: Union[list[float], None] = None,
    std: Union[list[float], None] = None,
    indices: Union[list[int], None] = None,
):
    """Build an input example by loading images in *file_paths*.

    Args:
        file_paths: list of file paths.
        mean: list containing mean values for each band in the images
              in *file_paths*. Normalization is applied only when both
              *mean* and *std* are given.
        std: list containing std values for each band in the images
             in *file_paths*.
        indices: optional band indices to select from every image before
                 normalization; all bands are kept when None.

    Returns:
        np.ndarray (float32) with the stacked images and a leading batch
        dimension of size 1
        list of [year, day_of_year] pairs, one per file whose name contains a
        timestamp of the form YYYYDDDTHHMMSS or YYYYMMDDTHHMMSS
        list of location coordinates for the files where they could be read
        list of meta info for each image in *file_paths*
    """

    imgs = []
    metas = []
    temporal_coords = []
    location_coords = []

    for file in file_paths:
        img, meta, coords = read_geotiff(file)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        if indices is not None:
            img = img[..., indices]
        if mean is not None and std is not None:
            img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)
        if coords is not None:
            location_coords.append(coords)

        # Timestamp extraction is best effort: a failure is reported and the
        # file simply contributes no temporal coordinate.
        try:
            match = re.search(r"(\d{7,8}T\d{6})", file)
            if match:
                date_part = match.group(1).split("T")[0]
                year = int(date_part[:4])
                if len(date_part) == 7:
                    # YYYYDDD: the day of year is given directly
                    julian_day = int(date_part[4:])
                else:
                    # YYYYMMDD: parse together with the year. Parsing only
                    # "%m%d" would default to 1900 (not a leap year), which
                    # rejects Feb 29 and is off by one after it in leap years.
                    julian_day = (
                        datetime.datetime.strptime(date_part, "%Y%m%d")
                        .timetuple()
                        .tm_yday
                    )
                temporal_coords.append([year, julian_day])
        except (TypeError, ValueError) as e:
            print(f"Could not extract timestamp for {file} ({e})")

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, temporal_coords, location_coords, metas


def run_model(
    input_data,
    temporal_coords,
    location_coords,
    model,
    datamodule,
    img_size,
    lightning_model=None,
):
    """Run tiled inference over *input_data* and stitch the predictions.

    The input is padded to a multiple of *img_size*, cut into non-overlapping
    img_size x img_size windows, each window is pushed through *model*, and
    the per-window class maps are reassembled and cropped to the input size.

    Args:
        input_data: 5-D array; the rearrange pattern below labels its axes
            as (b, c, t, h, w).
        temporal_coords: list of temporal coordinates; only forwarded to
            *lightning_model*. Empty or None means "not available".
        location_coords: list of location coordinates; only the first entry
            is used. Empty or None means "not available".
        model: object exposing ``run(x, location_coords=...)``.
        datamodule: object exposing ``test_transform`` and ``aug``.
        img_size: side length of the square windows.
        lightning_model: optional reference model; when given, its output is
            compared against *model*'s and a message is printed on mismatch.

    Returns:
        torch.Tensor holding the argmax class index per pixel (as float),
        with a leading dimension of size 1 followed by the original height
        and width.
    """
    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = (img_size - (original_h % img_size)) % img_size
    pad_w = (img_size - (original_w % img_size)) % img_size
    # Padding is added only at the bottom/right edges of the last two axes.
    input_data = np.pad(
        input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
    )

    # Build sliding window
    batch_size = 1
    batch = torch.tensor(input_data, device="cpu")
    # Step == window size, so the windows tile the image without overlap.
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    # Number of windows along height and width; needed to stitch them back.
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
        windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
    )

    # Split into batches if number of windows > batch_size
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Convert the coordinate lists to tensors with a leading dimension of 1;
    # empty/None inputs are normalized to None.
    if temporal_coords:
        temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
    else:
        temporal_coords = None
    if location_coords:
        location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
    else:
        location_coords = None

    # Run model
    pred_imgs = []
    for x in windows:
        # Apply standardization
        # NOTE(review): squeeze() + transpose(1, 2, 0) assumes exactly three
        # non-singleton axes remain (channels, h, w) — confirm for multi-frame
        # input.
        x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
        x = datamodule.aug(x)["image"]

        with torch.no_grad():
            x = x.to(device)
            pred = model.run(x, location_coords=location_coords)
            # Optional sanity check against a reference implementation.
            if lightning_model:
                pred_lightning = lightning_model(
                    x, temporal_coords=temporal_coords, location_coords=location_coords
                )
                pred_lightning = pred_lightning.output.detach().cpu()
                if not torch.equal(pred, pred_lightning):
                    print("Inference output is not equal")
        # Class index with the highest score along dim 1.
        y_hat = pred.argmax(dim=1)

        # Bring the class map to img_size x img_size (presumably undoing a
        # resize done by the test transform — TODO confirm).
        y_hat = torch.nn.functional.interpolate(
            y_hat.unsqueeze(1).float(), size=img_size, mode="nearest"
        )

        pred_imgs.append(y_hat)

    pred_imgs = torch.concat(pred_imgs, dim=0)

    # Build images from patches
    pred_imgs = rearrange(
        pred_imgs,
        "(b h1 w1) c h w -> b c (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=1,
        h1=h1,
        w1=w1,
    )

    # Cut padded area back to original size
    pred_imgs = pred_imgs[..., :original_h, :original_w]

    # Squeeze (batch size 1)
    pred_imgs = pred_imgs[0]

    return pred_imgs


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: optional list of argument strings to parse; when None the
              arguments are taken from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``data_file``, ``output_dir``,
        ``input_indices`` and ``rgb_outputs`` (the keyword arguments of
        ``main``).
    """
    parser = argparse.ArgumentParser("MAE run inference", add_help=False)

    parser.add_argument(
        "--data_file",
        type=str,
        default="./India_900498_S2Hand.tif",
        help="Path to the file.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--input_indices",
        default=[1, 2, 3, 8, 11, 12],
        type=int,
        nargs="+",
        help="0-based indices of the six Prithvi channels to be selected from the  "
        "input. By default selects [1,2,3,8,11,12] for S2L1C data.",
    )
    parser.add_argument(
        "--rgb_outputs",
        action="store_true",
        help="If present, output files will only contain RGB channels. "
        "Otherwise, all bands will be saved.",
    )

    # Without this return the caller receives None and cannot unpack the
    # options into main().
    return parser.parse_args(argv)


def main(
    data_file: str,
    output_dir: str,
    rgb_outputs: bool,
    input_indices: list[int] = None,
):
    """Run inference on one image and write the results as GeoTIFF files.

    Args:
        data_file: path of the input image.
        output_dir: directory receiving the outputs (created if missing).
        rgb_outputs: when True, the rescaled RGB view of the input is saved
            as well.
        input_indices: band indices to select from the input image; all
            bands are used when None.
    """
    os.makedirs(output_dir, exist_ok=True)

    # --- Model and datamodule -------------------------------------------------

    model_obj = PrithviMAE()
    datamodule = generate_datamodule()
    img_size = 256  # Size of Sen1Floods11

    # --- Input ----------------------------------------------------------------

    input_data, temporal_coords, location_coords, metas = load_example(
        file_paths=[data_file],
        indices=input_indices,
    )
    meta_data = metas[0]  # a single image was loaded

    # Bring the data to the 0-1 range when it is not already there
    if input_data.mean() > 1:
        input_data = input_data / 10000

    # --- Inference ------------------------------------------------------------

    # Positions of the RED, GREEN and BLUE bands (BGR -> RGB)
    rgb_band_names = ["RED", "GREEN", "BLUE"]
    channels = [datamodule_config["bands"].index(b) for b in rgb_band_names]

    pred = run_model(
        input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
    )

    # --- Outputs --------------------------------------------------------------

    stem = os.path.splitext(os.path.basename(data_file))[0]

    # 1) prediction mask (single band)
    meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
    pred_file = os.path.join(output_dir, f"pred_{stem}.tiff")
    save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)

    # 2) prediction blended over the RGB image (three bands)
    meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0)

    # Back to the 0-10000 range expected by the contrast rescaling
    if input_data.mean() < 1:
        input_data = input_data * 10000

    first_frame = torch.Tensor(input_data[0, :, 0, ...])
    rgb_orig = process_channel_group(orig_img=first_frame, channels=channels)

    # Pixels predicted as class 0 become NaN so that the blend below can be
    # replaced by the plain RGB value at those positions.
    pred[pred == 0.0] = np.nan
    img_pred = rgb_orig * 0.7 + pred * 0.3
    nan_mask = img_pred.isnan()
    img_pred[nan_mask] = rgb_orig[nan_mask]

    img_pred_file = os.path.join(output_dir, f"rgb_pred_{stem}.tiff")
    save_geotiff(
        image=_convert_np_uint8(img_pred),
        output_path=img_pred_file,
        meta=meta_data,
    )

    # 3) optional: the rescaled RGB image on its own
    if not rgb_outputs:
        return

    rgb_file = os.path.join(output_dir, f"original_rgb_{stem}.tiff")
    save_geotiff(
        image=_convert_np_uint8(rgb_orig),
        output_path=rgb_file,
        meta=meta_data,
    )


if __name__ == "__main__":
    # Each parsed option maps one-to-one onto a keyword argument of main().
    main(**vars(parse_args()))